Lint cleanup: ruff fixes across entire codebase

- Strip trailing whitespace from all Python files
- Fix import sorting (I001) across all modules
- Convert Optional[X] to X | None syntax (UP045)
- Remove unused imports (F401)
- Convert lambda assignments to def functions (E731)
- Add TYPE_CHECKING import for forward references
- Update pyproject.toml ruff config:
  - Move select/ignore to [tool.ruff.lint] section
  - Add per-file ignores for DCT colorspace naming (N803/N806)
  - Add per-file ignores for __init__.py import structure (E402)
  - Exclude defunct test_routes.py
- Remove frontends/web/test_routes.py (defunct debug snippet)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee
2026-01-02 17:17:38 -05:00
parent d94ee7be90
commit 6b21190f97
36 changed files with 2275 additions and 2383 deletions

View File

@@ -51,30 +51,30 @@ print("Testing scipy DCT...")
try: try:
from scipy.fftpack import dct, idct from scipy.fftpack import dct, idct
import numpy as np import numpy as np
# Create test array # Create test array
test = np.random.rand(8, 8).astype(np.float64) test = np.random.rand(8, 8).astype(np.float64)
print(f"Input array shape: {test.shape}, dtype: {test.dtype}") print(f"Input array shape: {test.shape}, dtype: {test.dtype}")
# Test 1D DCT # Test 1D DCT
row = test[0, :] row = test[0, :]
result = dct(row, norm='ortho') result = dct(row, norm='ortho')
print(f"1D DCT result shape: {result.shape}, dtype: {result.dtype}") print(f"1D DCT result shape: {result.shape}, dtype: {result.dtype}")
# Test 2D DCT (the potentially problematic operation) # Test 2D DCT (the potentially problematic operation)
result2d = dct(dct(test.T, norm='ortho').T, norm='ortho') result2d = dct(dct(test.T, norm='ortho').T, norm='ortho')
print(f"2D DCT result shape: {result2d.shape}, dtype: {result2d.dtype}") print(f"2D DCT result shape: {result2d.shape}, dtype: {result2d.dtype}")
# Test inverse # Test inverse
recovered = idct(idct(result2d.T, norm='ortho').T, norm='ortho') recovered = idct(idct(result2d.T, norm='ortho').T, norm='ortho')
error = np.max(np.abs(test - recovered)) error = np.max(np.abs(test - recovered))
print(f"Round-trip error: {error}") print(f"Round-trip error: {error}")
if error < 1e-10: if error < 1e-10:
print("✓ scipy DCT working correctly") print("✓ scipy DCT working correctly")
else: else:
print("⚠ scipy DCT has precision issues") print("⚠ scipy DCT has precision issues")
except Exception as e: except Exception as e:
print(f"✗ scipy DCT failed: {e}") print(f"✗ scipy DCT failed: {e}")
import traceback import traceback
@@ -90,11 +90,11 @@ try:
from scipy.fftpack import dct, idct from scipy.fftpack import dct, idct
import numpy as np import numpy as np
import gc import gc
# Simulate processing many 8x8 blocks # Simulate processing many 8x8 blocks
large_array = np.random.rand(512, 512).astype(np.float64) large_array = np.random.rand(512, 512).astype(np.float64)
print(f"Large array shape: {large_array.shape}, size: {large_array.nbytes} bytes") print(f"Large array shape: {large_array.shape}, size: {large_array.nbytes} bytes")
count = 0 count = 0
for y in range(0, 512, 8): for y in range(0, 512, 8):
for x in range(0, 512, 8): for x in range(0, 512, 8):
@@ -103,14 +103,14 @@ try:
recovered = idct(idct(dct_block.T, norm='ortho').T, norm='ortho') recovered = idct(idct(dct_block.T, norm='ortho').T, norm='ortho')
large_array[y:y+8, x:x+8] = recovered large_array[y:y+8, x:x+8] = recovered
count += 1 count += 1
print(f"Processed {count} blocks successfully") print(f"Processed {count} blocks successfully")
del large_array del large_array
gc.collect() gc.collect()
print("✓ Large array processing completed") print("✓ Large array processing completed")
except Exception as e: except Exception as e:
print(f"✗ Large array processing failed: {e}") print(f"✗ Large array processing failed: {e}")
import traceback import traceback
@@ -125,26 +125,26 @@ print("Testing PIL with large image...")
try: try:
from PIL import Image from PIL import Image
import io import io
# Create a large test image # Create a large test image
img = Image.new('RGB', (4000, 3000), color=(128, 128, 128)) img = Image.new('RGB', (4000, 3000), color=(128, 128, 128))
# Save to bytes # Save to bytes
buffer = io.BytesIO() buffer = io.BytesIO()
img.save(buffer, format='PNG') img.save(buffer, format='PNG')
img_bytes = buffer.getvalue() img_bytes = buffer.getvalue()
print(f"Test image size: {len(img_bytes)} bytes") print(f"Test image size: {len(img_bytes)} bytes")
# Re-open and process # Re-open and process
buffer2 = io.BytesIO(img_bytes) buffer2 = io.BytesIO(img_bytes)
img2 = Image.open(buffer2) img2 = Image.open(buffer2)
print(f"Re-opened image: {img2.size}, mode: {img2.mode}") print(f"Re-opened image: {img2.size}, mode: {img2.mode}")
# Convert to numpy array # Convert to numpy array
import numpy as np import numpy as np
arr = np.array(img2) arr = np.array(img2)
print(f"NumPy array: {arr.shape}, dtype: {arr.dtype}") print(f"NumPy array: {arr.shape}, dtype: {arr.dtype}")
# Clean up # Clean up
img.close() img.close()
img2.close() img2.close()
@@ -152,9 +152,9 @@ try:
buffer2.close() buffer2.close()
del arr del arr
gc.collect() gc.collect()
print("✓ PIL large image test completed") print("✓ PIL large image test completed")
except Exception as e: except Exception as e:
print(f"✗ PIL test failed: {e}") print(f"✗ PIL test failed: {e}")
import traceback import traceback

View File

@@ -69,13 +69,13 @@ def main():
print("\nOptional: add passphrase, pin, key path") print("\nOptional: add passphrase, pin, key path")
print(" python debug_jpegio.py stego.jpg ref.jpg 'passphrase' '123456' key.pem") print(" python debug_jpegio.py stego.jpg ref.jpg 'passphrase' '123456' key.pem")
sys.exit(1) sys.exit(1)
stego_path = sys.argv[1] stego_path = sys.argv[1]
ref_path = sys.argv[2] ref_path = sys.argv[2]
passphrase = sys.argv[3] if len(sys.argv) > 3 else "test" passphrase = sys.argv[3] if len(sys.argv) > 3 else "test"
pin = sys.argv[4] if len(sys.argv) > 4 else "" pin = sys.argv[4] if len(sys.argv) > 4 else ""
key_path = sys.argv[5] if len(sys.argv) > 5 else None key_path = sys.argv[5] if len(sys.argv) > 5 else None
print(f"\n{'='*60}") print(f"\n{'='*60}")
print("JPEGIO DCT EXTRACTION DEBUG") print("JPEGIO DCT EXTRACTION DEBUG")
print(f"{'='*60}") print(f"{'='*60}")
@@ -84,7 +84,7 @@ def main():
print(f"Passphrase: '{passphrase}'") print(f"Passphrase: '{passphrase}'")
print(f"PIN: '{pin}'") print(f"PIN: '{pin}'")
print(f"Key: {key_path}") print(f"Key: {key_path}")
# Load stego image with jpegio # Load stego image with jpegio
print(f"\n[1] Loading stego image with jpegio...") print(f"\n[1] Loading stego image with jpegio...")
try: try:
@@ -96,7 +96,7 @@ def main():
except Exception as e: except Exception as e:
print(f" ✗ Failed: {e}") print(f" ✗ Failed: {e}")
sys.exit(1) sys.exit(1)
# Get coefficient array (channel 0) # Get coefficient array (channel 0)
coef_array = jpeg.coef_arrays[0] coef_array = jpeg.coef_arrays[0]
print(f"\n[2] Coefficient array analysis...") print(f"\n[2] Coefficient array analysis...")
@@ -104,21 +104,21 @@ def main():
print(f" Non-zero coefficients: {np.count_nonzero(coef_array)}") print(f" Non-zero coefficients: {np.count_nonzero(coef_array)}")
print(f" Min value: {coef_array.min()}") print(f" Min value: {coef_array.min()}")
print(f" Max value: {coef_array.max()}") print(f" Max value: {coef_array.max()}")
# Get usable positions # Get usable positions
print(f"\n[3] Finding usable positions (|coef| >= 2, non-DC)...") print(f"\n[3] Finding usable positions (|coef| >= 2, non-DC)...")
positions = get_usable_positions(coef_array) positions = get_usable_positions(coef_array)
print(f" Usable positions: {len(positions)}") print(f" Usable positions: {len(positions)}")
print(f" Capacity: ~{len(positions) // 8} bytes") print(f" Capacity: ~{len(positions) // 8} bytes")
# Generate seed (this needs to match the encode seed!) # Generate seed (this needs to match the encode seed!)
print(f"\n[4] Generating seed...") print(f"\n[4] Generating seed...")
# Load reference photo # Load reference photo
ref_data = Path(ref_path).read_bytes() ref_data = Path(ref_path).read_bytes()
ref_hash = hashlib.sha256(ref_data).digest() ref_hash = hashlib.sha256(ref_data).digest()
print(f" Reference hash: {ref_hash[:8].hex()}...") print(f" Reference hash: {ref_hash[:8].hex()}...")
# Load RSA key if provided # Load RSA key if provided
rsa_component = b"" rsa_component = b""
if key_path: if key_path:
@@ -130,7 +130,7 @@ def main():
rsa_key = load_rsa_key(key_data, password=None) rsa_key = load_rsa_key(key_data, password=None)
except: except:
rsa_key = load_rsa_key(key_data, password="testpass") rsa_key = load_rsa_key(key_data, password="testpass")
# Get public key bytes for seed # Get public key bytes for seed
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
pub_bytes = rsa_key.public_key().public_bytes( pub_bytes = rsa_key.public_key().public_bytes(
@@ -141,7 +141,7 @@ def main():
print(f" RSA key loaded, hash: {rsa_component[:8].hex()}...") print(f" RSA key loaded, hash: {rsa_component[:8].hex()}...")
except Exception as e: except Exception as e:
print(f" ✗ Could not load RSA key: {e}") print(f" ✗ Could not load RSA key: {e}")
# Build seed like stegasoo does # Build seed like stegasoo does
# This is the critical part - must match encoding! # This is the critical part - must match encoding!
seed_parts = [ seed_parts = [
@@ -152,12 +152,12 @@ def main():
] ]
seed = hashlib.sha256(b"".join(seed_parts)).digest() seed = hashlib.sha256(b"".join(seed_parts)).digest()
print(f" Combined seed: {seed[:8].hex()}...") print(f" Combined seed: {seed[:8].hex()}...")
# Generate order # Generate order
print(f"\n[5] Generating coefficient order...") print(f"\n[5] Generating coefficient order...")
order = generate_order(len(positions), seed) order = generate_order(len(positions), seed)
print(f" First 10 indices: {order[:10]}") print(f" First 10 indices: {order[:10]}")
# Try to extract header # Try to extract header
print(f"\n[6] Extracting header (first 80 bits = 10 bytes)...") print(f"\n[6] Extracting header (first 80 bits = 10 bytes)...")
HEADER_SIZE = 10 HEADER_SIZE = 10
@@ -165,7 +165,7 @@ def main():
header_bytes = bits_to_bytes(header_bits) header_bytes = bits_to_bytes(header_bits)
print(f" Raw header bytes: {header_bytes.hex()}") print(f" Raw header bytes: {header_bytes.hex()}")
print(f" As ASCII (if printable): {repr(header_bytes)}") print(f" As ASCII (if printable): {repr(header_bytes)}")
# Check for JPGS magic # Check for JPGS magic
JPEGIO_MAGIC = b'JPGS' JPEGIO_MAGIC = b'JPGS'
if header_bytes[:4] == JPEGIO_MAGIC: if header_bytes[:4] == JPEGIO_MAGIC:
@@ -176,7 +176,7 @@ def main():
print(f" Version: {version}") print(f" Version: {version}")
print(f" Flags: {flags}") print(f" Flags: {flags}")
print(f" Data length: {data_length} bytes") print(f" Data length: {data_length} bytes")
if data_length > 0 and data_length < len(positions) // 8: if data_length > 0 and data_length < len(positions) // 8:
print(f"\n[7] Extracting payload ({data_length} bytes)...") print(f"\n[7] Extracting payload ({data_length} bytes)...")
total_bits = (HEADER_SIZE + data_length) * 8 total_bits = (HEADER_SIZE + data_length) * 8
@@ -191,10 +191,10 @@ def main():
print(f" ✗ No JPEGIO magic found") print(f" ✗ No JPEGIO magic found")
print(f" Expected: {JPEGIO_MAGIC.hex()} ('JPGS')") print(f" Expected: {JPEGIO_MAGIC.hex()} ('JPGS')")
print(f" Got: {header_bytes[:4].hex()} ('{header_bytes[:4]}')") print(f" Got: {header_bytes[:4].hex()} ('{header_bytes[:4]}')")
# Try alternate interpretations # Try alternate interpretations
print(f"\n[7] Trying alternate header interpretations...") print(f"\n[7] Trying alternate header interpretations...")
# Maybe it's scipy DCT format? # Maybe it's scipy DCT format?
DCT_MAGIC = b'DCTS' DCT_MAGIC = b'DCTS'
if header_bytes[:4] == DCT_MAGIC: if header_bytes[:4] == DCT_MAGIC:
@@ -202,7 +202,7 @@ def main():
else: else:
# Show bit distribution # Show bit distribution
print(f" First 32 extracted bits: {header_bits[:32]}") print(f" First 32 extracted bits: {header_bits[:32]}")
# Check if bits look random or patterned # Check if bits look random or patterned
ones = sum(header_bits[:80]) ones = sum(header_bits[:80])
print(f" Bit distribution: {ones}/80 ones ({100*ones/80:.1f}%)") print(f" Bit distribution: {ones}/80 ones ({100*ones/80:.1f}%)")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -22,20 +22,16 @@ NEW in v3.0.1: DCT output format selection (PNG or JPEG) and color mode (graysca
""" """
import io import io
import mimetypes
import os
import secrets
import sys import sys
import time import time
import secrets
import mimetypes
from pathlib import Path from pathlib import Path
from datetime import datetime
from flask import Flask, flash, jsonify, redirect, render_template, request, send_file, url_for
from PIL import Image from PIL import Image
from flask import (
Flask, render_template, request, send_file,
jsonify, flash, redirect, url_for
)
import os
os.environ['NUMPY_MADVISE_HUGEPAGE'] = '0' os.environ['NUMPY_MADVISE_HUGEPAGE'] = '0'
os.environ['OMP_NUM_THREADS'] = '1' os.environ['OMP_NUM_THREADS'] = '1'
@@ -44,75 +40,76 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'src'))
import stegasoo import stegasoo
from stegasoo import ( from stegasoo import (
generate_credentials, CapacityError,
export_rsa_key_pem, load_rsa_key, DecryptionError,
validate_pin, validate_message, validate_image,
validate_rsa_key, validate_security_factors,
validate_file_payload, validate_passphrase,
generate_filename,
StegasooError, DecryptionError, CapacityError,
has_argon2,
FilePayload, FilePayload,
# Embedding modes StegasooError,
EMBED_MODE_LSB, export_rsa_key_pem,
EMBED_MODE_DCT, generate_credentials,
EMBED_MODE_AUTO, generate_filename,
has_dct_support,
# Channel key functions (v4.0.0)
has_channel_key,
get_channel_status, get_channel_status,
has_argon2,
# Channel key functions (v4.0.0)
has_dct_support,
load_rsa_key,
validate_channel_key, validate_channel_key,
generate_channel_key, validate_file_payload,
# NOTE: encode, decode, compare_modes, will_fit_by_mode now use subprocess isolation validate_image,
validate_message,
validate_passphrase,
validate_pin,
validate_rsa_key,
validate_security_factors,
) )
from stegasoo.constants import ( from stegasoo.constants import (
__version__,
MAX_MESSAGE_SIZE, MAX_MESSAGE_CHARS,
MIN_PIN_LENGTH, MAX_PIN_LENGTH,
MIN_PASSPHRASE_WORDS, RECOMMENDED_PASSPHRASE_WORDS,
DEFAULT_PASSPHRASE_WORDS, DEFAULT_PASSPHRASE_WORDS,
VALID_RSA_SIZES, MAX_FILE_SIZE, MAX_FILE_PAYLOAD_SIZE,
MAX_FILE_PAYLOAD_SIZE, MAX_UPLOAD_SIZE, MAX_FILE_SIZE,
TEMP_FILE_EXPIRY, TEMP_FILE_EXPIRY_MINUTES, MAX_MESSAGE_CHARS,
THUMBNAIL_SIZE, THUMBNAIL_QUALITY, MAX_PIN_LENGTH,
MAX_UPLOAD_SIZE,
MIN_PASSPHRASE_WORDS,
MIN_PIN_LENGTH,
RECOMMENDED_PASSPHRASE_WORDS,
TEMP_FILE_EXPIRY,
TEMP_FILE_EXPIRY_MINUTES,
THUMBNAIL_QUALITY,
THUMBNAIL_SIZE,
VALID_RSA_SIZES,
__version__,
) )
# QR Code support # QR Code support
try: try:
import qrcode import qrcode # noqa: F401
from qrcode.constants import ERROR_CORRECT_L, ERROR_CORRECT_M from qrcode.constants import ERROR_CORRECT_L, ERROR_CORRECT_M # noqa: F401
HAS_QRCODE = True HAS_QRCODE = True
except ImportError: except ImportError:
HAS_QRCODE = False HAS_QRCODE = False
# QR Code reading # QR Code reading
try: try:
from pyzbar.pyzbar import decode as pyzbar_decode from pyzbar.pyzbar import decode as pyzbar_decode # noqa: F401
HAS_QRCODE_READ = True HAS_QRCODE_READ = True
except ImportError: except ImportError:
HAS_QRCODE_READ = False HAS_QRCODE_READ = False
import zlib
import base64
# Import QR utilities # Import QR utilities
from stegasoo.qr_utils import (
compress_data, decompress_data, auto_decompress,
is_compressed, can_fit_in_qr, needs_compression,
generate_qr_code, read_qr_code, extract_key_from_qr,
detect_and_crop_qr,
has_qr_write, has_qr_read,
QR_MAX_BINARY, COMPRESSION_PREFIX
)
# ============================================================================ # ============================================================================
# SUBPROCESS ISOLATION FOR STEGASOO OPERATIONS # SUBPROCESS ISOLATION FOR STEGASOO OPERATIONS
# ============================================================================ # ============================================================================
# Runs encode/decode/compare in subprocesses to prevent jpegio/scipy crashes # Runs encode/decode/compare in subprocesses to prevent jpegio/scipy crashes
# from taking down the Flask server. # from taking down the Flask server.
from subprocess_stego import SubprocessStego from subprocess_stego import SubprocessStego
from stegasoo.qr_utils import (
can_fit_in_qr,
detect_and_crop_qr,
extract_key_from_qr,
generate_qr_code,
)
# Initialize subprocess wrapper (worker script must be in same directory) # Initialize subprocess wrapper (worker script must be in same directory)
subprocess_stego = SubprocessStego(timeout=180) # 3 minute timeout for large images subprocess_stego = SubprocessStego(timeout=180) # 3 minute timeout for large images
@@ -139,7 +136,7 @@ def inject_globals():
"""Inject global variables into all templates.""" """Inject global variables into all templates."""
# Get channel status (v4.0.0) # Get channel status (v4.0.0)
channel_status = get_channel_status() channel_status = get_channel_status()
return { return {
'version': __version__, 'version': __version__,
'max_message_chars': MAX_MESSAGE_CHARS, 'max_message_chars': MAX_MESSAGE_CHARS,
@@ -172,20 +169,20 @@ try:
print(f"Current MAX_FILE_PAYLOAD_SIZE: {MAX_FILE_PAYLOAD_SIZE}") print(f"Current MAX_FILE_PAYLOAD_SIZE: {MAX_FILE_PAYLOAD_SIZE}")
print(f"DCT support: {has_dct_support()}") print(f"DCT support: {has_dct_support()}")
print(f"QR code support: write={HAS_QRCODE}, read={HAS_QRCODE_READ}") print(f"QR code support: write={HAS_QRCODE}, read={HAS_QRCODE_READ}")
# Channel key status (v4.0.0) # Channel key status (v4.0.0)
channel_status = get_channel_status() channel_status = get_channel_status()
print(f"Channel key: {channel_status['mode']} mode") print(f"Channel key: {channel_status['mode']} mode")
if channel_status['configured']: if channel_status['configured']:
print(f" Fingerprint: {channel_status.get('fingerprint')}") print(f" Fingerprint: {channel_status.get('fingerprint')}")
print(f" Source: {channel_status.get('source')}") print(f" Source: {channel_status.get('source')}")
DESIRED_PAYLOAD_SIZE = 2 * 1024 * 1024 # 2MB DESIRED_PAYLOAD_SIZE = 2 * 1024 * 1024 # 2MB
if hasattr(stegasoo, 'MAX_FILE_PAYLOAD_SIZE'): if hasattr(stegasoo, 'MAX_FILE_PAYLOAD_SIZE'):
print(f"Overriding MAX_FILE_PAYLOAD_SIZE to {DESIRED_PAYLOAD_SIZE}") print(f"Overriding MAX_FILE_PAYLOAD_SIZE to {DESIRED_PAYLOAD_SIZE}")
stegasoo.MAX_FILE_PAYLOAD_SIZE = DESIRED_PAYLOAD_SIZE stegasoo.MAX_FILE_PAYLOAD_SIZE = DESIRED_PAYLOAD_SIZE
except Exception as e: except Exception as e:
print(f"Could not override stegasoo limits: {e}") print(f"Could not override stegasoo limits: {e}")
@@ -197,10 +194,10 @@ except Exception as e:
def resolve_channel_key_form(channel_key_value: str) -> str: def resolve_channel_key_form(channel_key_value: str) -> str:
""" """
Resolve channel key from form input. Resolve channel key from form input.
Args: Args:
channel_key_value: Form value ('auto', 'none', or explicit key) channel_key_value: Form value ('auto', 'none', or explicit key)
Returns: Returns:
Value to pass to subprocess_stego ('auto', 'none', or explicit key) Value to pass to subprocess_stego ('auto', 'none', or explicit key)
""" """
@@ -234,10 +231,10 @@ def generate_thumbnail(image_data: bytes, size: tuple = THUMBNAIL_SIZE) -> bytes
img = img.convert('RGB') img = img.convert('RGB')
elif img.mode != 'RGB': elif img.mode != 'RGB':
img = img.convert('RGB') img = img.convert('RGB')
# Create thumbnail # Create thumbnail
img.thumbnail(size, Image.Resampling.LANCZOS) img.thumbnail(size, Image.Resampling.LANCZOS)
# Save to bytes # Save to bytes
buffer = io.BytesIO() buffer = io.BytesIO()
img.save(buffer, format='JPEG', quality=THUMBNAIL_QUALITY, optimize=True) img.save(buffer, format='JPEG', quality=THUMBNAIL_QUALITY, optimize=True)
@@ -251,7 +248,7 @@ def cleanup_temp_files():
"""Remove expired temporary files.""" """Remove expired temporary files."""
now = time.time() now = time.time()
expired = [fid for fid, info in TEMP_FILES.items() if now - info['timestamp'] > TEMP_FILE_EXPIRY] expired = [fid for fid, info in TEMP_FILES.items() if now - info['timestamp'] > TEMP_FILE_EXPIRY]
for fid in expired: for fid in expired:
TEMP_FILES.pop(fid, None) TEMP_FILES.pop(fid, None)
# Also clean up corresponding thumbnail # Also clean up corresponding thumbnail
@@ -294,12 +291,12 @@ def index():
def api_channel_status(): def api_channel_status():
""" """
Get current channel key status (v4.0.0). Get current channel key status (v4.0.0).
Returns JSON with mode, fingerprint, and source. Returns JSON with mode, fingerprint, and source.
""" """
# Use subprocess for isolation # Use subprocess for isolation
result = subprocess_stego.get_channel_status(reveal=False) result = subprocess_stego.get_channel_status(reveal=False)
if result.success: if result.success:
return jsonify({ return jsonify({
'success': True, 'success': True,
@@ -324,16 +321,16 @@ def api_channel_status():
def api_channel_validate(): def api_channel_validate():
""" """
Validate a channel key format (v4.0.0). Validate a channel key format (v4.0.0).
Returns JSON with validation result. Returns JSON with validation result.
""" """
key = request.form.get('key', '') or request.json.get('key', '') if request.is_json else '' key = request.form.get('key', '') or request.json.get('key', '') if request.is_json else ''
if not key: if not key:
return jsonify({'valid': False, 'error': 'No key provided'}) return jsonify({'valid': False, 'error': 'No key provided'})
is_valid = validate_channel_key(key) is_valid = validate_channel_key(key)
if is_valid: if is_valid:
fingerprint = f"{key[:4]}-••••-••••-••••-••••-••••-••••-{key[-4:]}" fingerprint = f"{key[:4]}-••••-••••-••••-••••-••••-••••-{key[-4:]}"
return jsonify({ return jsonify({
@@ -358,20 +355,20 @@ def generate():
words_per_passphrase = int(request.form.get('words_per_passphrase', DEFAULT_PASSPHRASE_WORDS)) words_per_passphrase = int(request.form.get('words_per_passphrase', DEFAULT_PASSPHRASE_WORDS))
use_pin = request.form.get('use_pin') == 'on' use_pin = request.form.get('use_pin') == 'on'
use_rsa = request.form.get('use_rsa') == 'on' use_rsa = request.form.get('use_rsa') == 'on'
if not use_pin and not use_rsa: if not use_pin and not use_rsa:
flash('You must select at least one security factor (PIN or RSA Key)', 'error') flash('You must select at least one security factor (PIN or RSA Key)', 'error')
return render_template('generate.html', generated=False, has_qrcode=HAS_QRCODE) return render_template('generate.html', generated=False, has_qrcode=HAS_QRCODE)
pin_length = int(request.form.get('pin_length', 6)) pin_length = int(request.form.get('pin_length', 6))
rsa_bits = int(request.form.get('rsa_bits', 2048)) rsa_bits = int(request.form.get('rsa_bits', 2048))
# Clamp values # Clamp values
words_per_passphrase = max(MIN_PASSPHRASE_WORDS, min(12, words_per_passphrase)) words_per_passphrase = max(MIN_PASSPHRASE_WORDS, min(12, words_per_passphrase))
pin_length = max(MIN_PIN_LENGTH, min(MAX_PIN_LENGTH, pin_length)) pin_length = max(MIN_PIN_LENGTH, min(MAX_PIN_LENGTH, pin_length))
if rsa_bits not in VALID_RSA_SIZES: if rsa_bits not in VALID_RSA_SIZES:
rsa_bits = 2048 rsa_bits = 2048
try: try:
# v3.2.0 FIX: Use correct parameter name 'passphrase_words' # v3.2.0 FIX: Use correct parameter name 'passphrase_words'
creds = generate_credentials( creds = generate_credentials(
@@ -381,19 +378,19 @@ def generate():
rsa_bits=rsa_bits, rsa_bits=rsa_bits,
passphrase_words=words_per_passphrase, # FIX: was words_per_passphrase= passphrase_words=words_per_passphrase, # FIX: was words_per_passphrase=
) )
# Store RSA key temporarily for QR generation # Store RSA key temporarily for QR generation
qr_token = None qr_token = None
qr_needs_compression = False qr_needs_compression = False
qr_too_large = False qr_too_large = False
if creds.rsa_key_pem and HAS_QRCODE: if creds.rsa_key_pem and HAS_QRCODE:
# Check if key fits in QR code # Check if key fits in QR code
if can_fit_in_qr(creds.rsa_key_pem, compress=True): if can_fit_in_qr(creds.rsa_key_pem, compress=True):
qr_needs_compression = True qr_needs_compression = True
else: else:
qr_too_large = True qr_too_large = True
if not qr_too_large: if not qr_too_large:
qr_token = secrets.token_urlsafe(16) qr_token = secrets.token_urlsafe(16)
cleanup_temp_files() cleanup_temp_files()
@@ -404,7 +401,7 @@ def generate():
'type': 'rsa_key', 'type': 'rsa_key',
'compress': qr_needs_compression 'compress': qr_needs_compression
} }
# v3.2.0: Single passphrase instead of daily phrases # v3.2.0: Single passphrase instead of daily phrases
return render_template('generate.html', return render_template('generate.html',
passphrase=creds.passphrase, # v3.2.0: Single passphrase passphrase=creds.passphrase, # v3.2.0: Single passphrase
@@ -428,7 +425,7 @@ def generate():
except Exception as e: except Exception as e:
flash(f'Error generating credentials: {e}', 'error') flash(f'Error generating credentials: {e}', 'error')
return render_template('generate.html', generated=False, has_qrcode=HAS_QRCODE) return render_template('generate.html', generated=False, has_qrcode=HAS_QRCODE)
return render_template('generate.html', generated=False, has_qrcode=HAS_QRCODE) return render_template('generate.html', generated=False, has_qrcode=HAS_QRCODE)
@@ -437,19 +434,19 @@ def generate_qr(token):
"""Generate QR code for RSA key.""" """Generate QR code for RSA key."""
if not HAS_QRCODE: if not HAS_QRCODE:
return "QR code support not available", 501 return "QR code support not available", 501
if token not in TEMP_FILES: if token not in TEMP_FILES:
return "Token expired or invalid", 404 return "Token expired or invalid", 404
file_info = TEMP_FILES[token] file_info = TEMP_FILES[token]
if file_info.get('type') != 'rsa_key': if file_info.get('type') != 'rsa_key':
return "Invalid token type", 400 return "Invalid token type", 400
try: try:
key_pem = file_info['data'].decode('utf-8') key_pem = file_info['data'].decode('utf-8')
compress = file_info.get('compress', False) compress = file_info.get('compress', False)
qr_png = generate_qr_code(key_pem, compress=compress) qr_png = generate_qr_code(key_pem, compress=compress)
return send_file( return send_file(
io.BytesIO(qr_png), io.BytesIO(qr_png),
mimetype='image/png', mimetype='image/png',
@@ -464,19 +461,19 @@ def generate_qr_download(token):
"""Download QR code as PNG file.""" """Download QR code as PNG file."""
if not HAS_QRCODE: if not HAS_QRCODE:
return "QR code support not available", 501 return "QR code support not available", 501
if token not in TEMP_FILES: if token not in TEMP_FILES:
return "Token expired or invalid", 404 return "Token expired or invalid", 404
file_info = TEMP_FILES[token] file_info = TEMP_FILES[token]
if file_info.get('type') != 'rsa_key': if file_info.get('type') != 'rsa_key':
return "Invalid token type", 400 return "Invalid token type", 400
try: try:
key_pem = file_info['data'].decode('utf-8') key_pem = file_info['data'].decode('utf-8')
compress = file_info.get('compress', False) compress = file_info.get('compress', False)
qr_png = generate_qr_code(key_pem, compress=compress) qr_png = generate_qr_code(key_pem, compress=compress)
return send_file( return send_file(
io.BytesIO(qr_png), io.BytesIO(qr_png),
mimetype='image/png', mimetype='image/png',
@@ -491,29 +488,29 @@ def generate_qr_download(token):
def qr_crop(): def qr_crop():
""" """
Detect and crop QR code from an image. Detect and crop QR code from an image.
Useful for extracting QR codes from photos taken at an angle, Useful for extracting QR codes from photos taken at an angle,
with extra background, etc. Returns the cropped QR as PNG. with extra background, etc. Returns the cropped QR as PNG.
""" """
if not HAS_QRCODE_READ: if not HAS_QRCODE_READ:
return jsonify({'error': 'QR code reading not available (install pyzbar)'}), 501 return jsonify({'error': 'QR code reading not available (install pyzbar)'}), 501
image_file = request.files.get('image') image_file = request.files.get('image')
if not image_file: if not image_file:
return jsonify({'error': 'No image provided'}), 400 return jsonify({'error': 'No image provided'}), 400
try: try:
image_data = image_file.read() image_data = image_file.read()
# Use the new crop function # Use the new crop function
cropped = detect_and_crop_qr(image_data) cropped = detect_and_crop_qr(image_data)
if cropped is None: if cropped is None:
return jsonify({'error': 'No QR code detected in image'}), 404 return jsonify({'error': 'No QR code detected in image'}), 404
# Return as downloadable PNG or inline based on query param # Return as downloadable PNG or inline based on query param
as_attachment = request.args.get('download', '').lower() in ('1', 'true', 'yes') as_attachment = request.args.get('download', '').lower() in ('1', 'true', 'yes')
return send_file( return send_file(
io.BytesIO(cropped), io.BytesIO(cropped),
mimetype='image/png', mimetype='image/png',
@@ -567,18 +564,18 @@ def extract_key_from_qr_route():
'success': False, 'success': False,
'error': 'QR code reading not available. Install pyzbar and libzbar.' 'error': 'QR code reading not available. Install pyzbar and libzbar.'
}), 501 }), 501
qr_image = request.files.get('qr_image') qr_image = request.files.get('qr_image')
if not qr_image: if not qr_image:
return jsonify({ return jsonify({
'success': False, 'success': False,
'error': 'No QR image provided' 'error': 'No QR image provided'
}), 400 }), 400
try: try:
image_data = qr_image.read() image_data = qr_image.read()
key_pem = extract_key_from_qr(image_data) key_pem = extract_key_from_qr(image_data)
if key_pem: if key_pem:
return jsonify({ return jsonify({
'success': True, 'success': True,
@@ -589,7 +586,7 @@ def extract_key_from_qr_route():
'success': False, 'success': False,
'error': 'No valid RSA key found in QR code' 'error': 'No valid RSA key found in QR code'
}), 400 }), 400
except Exception as e: except Exception as e:
return jsonify({ return jsonify({
'success': False, 'success': False,
@@ -611,16 +608,16 @@ def api_compare_capacity():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
if not carrier: if not carrier:
return jsonify({'error': 'No carrier image provided'}), 400 return jsonify({'error': 'No carrier image provided'}), 400
try: try:
carrier_data = carrier.read() carrier_data = carrier.read()
# Use subprocess-isolated compare_modes # Use subprocess-isolated compare_modes
result = subprocess_stego.compare_modes(carrier_data) result = subprocess_stego.compare_modes(carrier_data)
if not result.success: if not result.success:
return jsonify({'error': result.error or 'Comparison failed'}), 500 return jsonify({'error': result.error or 'Comparison failed'}), 500
return jsonify({ return jsonify({
'success': True, 'success': True,
'width': result.width, 'width': result.width,
@@ -652,29 +649,29 @@ def api_check_fit():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
payload_size = request.form.get('payload_size', type=int) payload_size = request.form.get('payload_size', type=int)
embed_mode = request.form.get('embed_mode', 'lsb') embed_mode = request.form.get('embed_mode', 'lsb')
if not carrier or payload_size is None: if not carrier or payload_size is None:
return jsonify({'error': 'Missing carrier or payload_size'}), 400 return jsonify({'error': 'Missing carrier or payload_size'}), 400
if embed_mode not in ('lsb', 'dct'): if embed_mode not in ('lsb', 'dct'):
return jsonify({'error': 'Invalid embed_mode'}), 400 return jsonify({'error': 'Invalid embed_mode'}), 400
if embed_mode == 'dct' and not has_dct_support(): if embed_mode == 'dct' and not has_dct_support():
return jsonify({'error': 'DCT mode requires scipy'}), 400 return jsonify({'error': 'DCT mode requires scipy'}), 400
try: try:
carrier_data = carrier.read() carrier_data = carrier.read()
# Use subprocess-isolated capacity check # Use subprocess-isolated capacity check
result = subprocess_stego.check_capacity( result = subprocess_stego.check_capacity(
carrier_data=carrier_data, carrier_data=carrier_data,
payload_size=payload_size, payload_size=payload_size,
embed_mode=embed_mode, embed_mode=embed_mode,
) )
if not result.success: if not result.success:
return jsonify({'error': result.error or 'Capacity check failed'}), 500 return jsonify({'error': result.error or 'Capacity check failed'}), 500
return jsonify({ return jsonify({
'success': True, 'success': True,
'fits': result.fits, 'fits': result.fits,
@@ -701,55 +698,55 @@ def encode_page():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
rsa_key_file = request.files.get('rsa_key') rsa_key_file = request.files.get('rsa_key')
payload_file = request.files.get('payload_file') payload_file = request.files.get('payload_file')
if not ref_photo or not carrier: if not ref_photo or not carrier:
flash('Both reference photo and carrier image are required', 'error') flash('Both reference photo and carrier image are required', 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
if not allowed_image(ref_photo.filename) or not allowed_image(carrier.filename): if not allowed_image(ref_photo.filename) or not allowed_image(carrier.filename):
flash('Invalid file type. Use PNG, JPG, or BMP', 'error') flash('Invalid file type. Use PNG, JPG, or BMP', 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# Get form data - v3.2.0: renamed from day_phrase to passphrase # Get form data - v3.2.0: renamed from day_phrase to passphrase
message = request.form.get('message', '') message = request.form.get('message', '')
passphrase = request.form.get('passphrase', '') # v3.2.0: Renamed passphrase = request.form.get('passphrase', '') # v3.2.0: Renamed
pin = request.form.get('pin', '').strip() pin = request.form.get('pin', '').strip()
rsa_password = request.form.get('rsa_password', '') rsa_password = request.form.get('rsa_password', '')
payload_type = request.form.get('payload_type', 'text') payload_type = request.form.get('payload_type', 'text')
# NEW in v3.0 - Embedding mode # NEW in v3.0 - Embedding mode
embed_mode = request.form.get('embed_mode', 'lsb') embed_mode = request.form.get('embed_mode', 'lsb')
if embed_mode not in ('lsb', 'dct'): if embed_mode not in ('lsb', 'dct'):
embed_mode = 'lsb' embed_mode = 'lsb'
# NEW in v3.0.1 - DCT output format # NEW in v3.0.1 - DCT output format
dct_output_format = request.form.get('dct_output_format', 'png') dct_output_format = request.form.get('dct_output_format', 'png')
if dct_output_format not in ('png', 'jpeg'): if dct_output_format not in ('png', 'jpeg'):
dct_output_format = 'png' dct_output_format = 'png'
# NEW in v3.0.1 - DCT color mode # NEW in v3.0.1 - DCT color mode
dct_color_mode = request.form.get('dct_color_mode', 'color') dct_color_mode = request.form.get('dct_color_mode', 'color')
if dct_color_mode not in ('grayscale', 'color'): if dct_color_mode not in ('grayscale', 'color'):
dct_color_mode = 'color' dct_color_mode = 'color'
# NEW in v4.0.0 - Channel key # NEW in v4.0.0 - Channel key
channel_key = resolve_channel_key_form(request.form.get('channel_key', 'auto')) channel_key = resolve_channel_key_form(request.form.get('channel_key', 'auto'))
# Check DCT availability # Check DCT availability
if embed_mode == 'dct' and not has_dct_support(): if embed_mode == 'dct' and not has_dct_support():
flash('DCT mode requires scipy. Install with: pip install scipy', 'error') flash('DCT mode requires scipy. Install with: pip install scipy', 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# Determine payload # Determine payload
if payload_type == 'file' and payload_file and payload_file.filename: if payload_type == 'file' and payload_file and payload_file.filename:
# File payload # File payload
file_data = payload_file.read() file_data = payload_file.read()
result = validate_file_payload(file_data, payload_file.filename) result = validate_file_payload(file_data, payload_file.filename)
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
mime_type, _ = mimetypes.guess_type(payload_file.filename) mime_type, _ = mimetypes.guess_type(payload_file.filename)
payload = FilePayload( payload = FilePayload(
data=file_data, data=file_data,
@@ -763,31 +760,31 @@ def encode_page():
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
payload = message payload = message
# v3.2.0: Renamed from day_phrase # v3.2.0: Renamed from day_phrase
if not passphrase: if not passphrase:
flash('Passphrase is required', 'error') flash('Passphrase is required', 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# v3.2.0: Validate passphrase # v3.2.0: Validate passphrase
result = validate_passphrase(passphrase) result = validate_passphrase(passphrase)
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# Show warning if passphrase is short # Show warning if passphrase is short
if result.warning: if result.warning:
flash(result.warning, 'warning') flash(result.warning, 'warning')
# Read files # Read files
ref_data = ref_photo.read() ref_data = ref_photo.read()
carrier_data = carrier.read() carrier_data = carrier.read()
# Handle RSA key - can come from .pem file or QR code image # Handle RSA key - can come from .pem file or QR code image
rsa_key_data = None rsa_key_data = None
rsa_key_qr = request.files.get('rsa_key_qr') rsa_key_qr = request.files.get('rsa_key_qr')
rsa_key_from_qr = False rsa_key_from_qr = False
if rsa_key_file and rsa_key_file.filename: if rsa_key_file and rsa_key_file.filename:
rsa_key_data = rsa_key_file.read() rsa_key_data = rsa_key_file.read()
elif rsa_key_qr and rsa_key_qr.filename and HAS_QRCODE_READ: elif rsa_key_qr and rsa_key_qr.filename and HAS_QRCODE_READ:
@@ -799,36 +796,36 @@ def encode_page():
else: else:
flash('Could not extract RSA key from QR code image.', 'error') flash('Could not extract RSA key from QR code image.', 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# Validate security factors # Validate security factors
result = validate_security_factors(pin, rsa_key_data) result = validate_security_factors(pin, rsa_key_data)
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# Validate PIN if provided # Validate PIN if provided
if pin: if pin:
result = validate_pin(pin) result = validate_pin(pin)
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# Determine key password # Determine key password
key_password = None if rsa_key_from_qr else (rsa_password if rsa_password else None) key_password = None if rsa_key_from_qr else (rsa_password if rsa_password else None)
# Validate RSA key if provided # Validate RSA key if provided
if rsa_key_data: if rsa_key_data:
result = validate_rsa_key(rsa_key_data, key_password) result = validate_rsa_key(rsa_key_data, key_password)
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# Validate carrier image # Validate carrier image
result = validate_image(carrier_data, "Carrier image") result = validate_image(carrier_data, "Carrier image")
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
# v4.0.0: Include channel_key parameter # v4.0.0: Include channel_key parameter
# Use subprocess-isolated encode to prevent crashes # Use subprocess-isolated encode to prevent crashes
if payload_type == 'file' and payload_file and payload_file.filename: if payload_type == 'file' and payload_file and payload_file.filename:
@@ -861,14 +858,14 @@ def encode_page():
dct_color_mode=dct_color_mode if embed_mode == 'dct' else 'color', dct_color_mode=dct_color_mode if embed_mode == 'dct' else 'color',
channel_key=channel_key, # v4.0.0 channel_key=channel_key, # v4.0.0
) )
# Check for subprocess errors # Check for subprocess errors
if not encode_result.success: if not encode_result.success:
error_msg = encode_result.error or 'Encoding failed' error_msg = encode_result.error or 'Encoding failed'
if 'capacity' in error_msg.lower(): if 'capacity' in error_msg.lower():
raise CapacityError(error_msg) raise CapacityError(error_msg)
raise StegasooError(error_msg) raise StegasooError(error_msg)
# Determine actual output format for filename and storage # Determine actual output format for filename and storage
if embed_mode == 'dct' and dct_output_format == 'jpeg': if embed_mode == 'dct' and dct_output_format == 'jpeg':
output_ext = '.jpg' output_ext = '.jpg'
@@ -876,14 +873,14 @@ def encode_page():
else: else:
output_ext = '.png' output_ext = '.png'
output_mime = 'image/png' output_mime = 'image/png'
# Use filename from result or generate one # Use filename from result or generate one
filename = encode_result.filename filename = encode_result.filename
if not filename: if not filename:
filename = generate_filename('stego', output_ext) filename = generate_filename('stego', output_ext)
elif embed_mode == 'dct' and dct_output_format == 'jpeg' and filename.endswith('.png'): elif embed_mode == 'dct' and dct_output_format == 'jpeg' and filename.endswith('.png'):
filename = filename[:-4] + '.jpg' filename = filename[:-4] + '.jpg'
# Store temporarily # Store temporarily
file_id = secrets.token_urlsafe(16) file_id = secrets.token_urlsafe(16)
cleanup_temp_files() cleanup_temp_files()
@@ -899,9 +896,9 @@ def encode_page():
'channel_mode': encode_result.channel_mode, 'channel_mode': encode_result.channel_mode,
'channel_fingerprint': encode_result.channel_fingerprint, 'channel_fingerprint': encode_result.channel_fingerprint,
} }
return redirect(url_for('encode_result', file_id=file_id)) return redirect(url_for('encode_result', file_id=file_id))
except CapacityError as e: except CapacityError as e:
flash(str(e), 'error') flash(str(e), 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
@@ -911,7 +908,7 @@ def encode_page():
except Exception as e: except Exception as e:
flash(f'Error: {e}', 'error') flash(f'Error: {e}', 'error')
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('encode.html', has_qrcode_read=HAS_QRCODE_READ)
@@ -920,17 +917,17 @@ def encode_result(file_id):
if file_id not in TEMP_FILES: if file_id not in TEMP_FILES:
flash('File expired or not found. Please encode again.', 'error') flash('File expired or not found. Please encode again.', 'error')
return redirect(url_for('encode_page')) return redirect(url_for('encode_page'))
file_info = TEMP_FILES[file_id] file_info = TEMP_FILES[file_id]
# Generate thumbnail # Generate thumbnail
thumbnail_data = generate_thumbnail(file_info['data']) thumbnail_data = generate_thumbnail(file_info['data'])
thumbnail_id = None thumbnail_id = None
if thumbnail_data: if thumbnail_data:
thumbnail_id = f"{file_id}_thumb" thumbnail_id = f"{file_id}_thumb"
THUMBNAIL_FILES[thumbnail_id] = thumbnail_data THUMBNAIL_FILES[thumbnail_id] = thumbnail_data
return render_template('encode_result.html', return render_template('encode_result.html',
file_id=file_id, file_id=file_id,
filename=file_info['filename'], filename=file_info['filename'],
@@ -949,7 +946,7 @@ def encode_thumbnail(thumb_id):
"""Serve thumbnail image.""" """Serve thumbnail image."""
if thumb_id not in THUMBNAIL_FILES: if thumb_id not in THUMBNAIL_FILES:
return "Thumbnail not found", 404 return "Thumbnail not found", 404
return send_file( return send_file(
io.BytesIO(THUMBNAIL_FILES[thumb_id]), io.BytesIO(THUMBNAIL_FILES[thumb_id]),
mimetype='image/jpeg', mimetype='image/jpeg',
@@ -962,10 +959,10 @@ def encode_download(file_id):
if file_id not in TEMP_FILES: if file_id not in TEMP_FILES:
flash('File expired or not found.', 'error') flash('File expired or not found.', 'error')
return redirect(url_for('encode_page')) return redirect(url_for('encode_page'))
file_info = TEMP_FILES[file_id] file_info = TEMP_FILES[file_id]
mime_type = file_info.get('mime_type', 'image/png') mime_type = file_info.get('mime_type', 'image/png')
return send_file( return send_file(
io.BytesIO(file_info['data']), io.BytesIO(file_info['data']),
mimetype=mime_type, mimetype=mime_type,
@@ -979,10 +976,10 @@ def encode_file_route(file_id):
"""Serve file for Web Share API.""" """Serve file for Web Share API."""
if file_id not in TEMP_FILES: if file_id not in TEMP_FILES:
return "Not found", 404 return "Not found", 404
file_info = TEMP_FILES[file_id] file_info = TEMP_FILES[file_id]
mime_type = file_info.get('mime_type', 'image/png') mime_type = file_info.get('mime_type', 'image/png')
return send_file( return send_file(
io.BytesIO(file_info['data']), io.BytesIO(file_info['data']),
mimetype=mime_type, mimetype=mime_type,
@@ -995,11 +992,11 @@ def encode_file_route(file_id):
def encode_cleanup(file_id): def encode_cleanup(file_id):
"""Manually cleanup a file after sharing.""" """Manually cleanup a file after sharing."""
TEMP_FILES.pop(file_id, None) TEMP_FILES.pop(file_id, None)
# Also cleanup thumbnail if exists # Also cleanup thumbnail if exists
thumb_id = f"{file_id}_thumb" thumb_id = f"{file_id}_thumb"
THUMBNAIL_FILES.pop(thumb_id, None) THUMBNAIL_FILES.pop(thumb_id, None)
return jsonify({'status': 'ok'}) return jsonify({'status': 'ok'})
@@ -1015,45 +1012,45 @@ def decode_page():
ref_photo = request.files.get('reference_photo') ref_photo = request.files.get('reference_photo')
stego_image = request.files.get('stego_image') stego_image = request.files.get('stego_image')
rsa_key_file = request.files.get('rsa_key') rsa_key_file = request.files.get('rsa_key')
if not ref_photo or not stego_image: if not ref_photo or not stego_image:
flash('Both reference photo and stego image are required', 'error') flash('Both reference photo and stego image are required', 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
# Get form data - v3.2.0: renamed from day_phrase to passphrase # Get form data - v3.2.0: renamed from day_phrase to passphrase
passphrase = request.form.get('passphrase', '') # v3.2.0: Renamed passphrase = request.form.get('passphrase', '') # v3.2.0: Renamed
pin = request.form.get('pin', '').strip() pin = request.form.get('pin', '').strip()
rsa_password = request.form.get('rsa_password', '') rsa_password = request.form.get('rsa_password', '')
# NEW in v3.0 - Extraction mode # NEW in v3.0 - Extraction mode
embed_mode = request.form.get('embed_mode', 'auto') embed_mode = request.form.get('embed_mode', 'auto')
if embed_mode not in ('auto', 'lsb', 'dct'): if embed_mode not in ('auto', 'lsb', 'dct'):
embed_mode = 'auto' embed_mode = 'auto'
# NEW in v4.0.0 - Channel key # NEW in v4.0.0 - Channel key
channel_key = resolve_channel_key_form(request.form.get('channel_key', 'auto')) channel_key = resolve_channel_key_form(request.form.get('channel_key', 'auto'))
# Check DCT availability # Check DCT availability
if embed_mode == 'dct' and not has_dct_support(): if embed_mode == 'dct' and not has_dct_support():
flash('DCT mode requires scipy. Install with: pip install scipy', 'error') flash('DCT mode requires scipy. Install with: pip install scipy', 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
# v3.2.0: Removed date handling (no stego_date needed) # v3.2.0: Removed date handling (no stego_date needed)
# v3.2.0: Renamed from day_phrase # v3.2.0: Renamed from day_phrase
if not passphrase: if not passphrase:
flash('Passphrase is required', 'error') flash('Passphrase is required', 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
# Read files # Read files
ref_data = ref_photo.read() ref_data = ref_photo.read()
stego_data = stego_image.read() stego_data = stego_image.read()
# Handle RSA key - can come from .pem file or QR code image # Handle RSA key - can come from .pem file or QR code image
rsa_key_data = None rsa_key_data = None
rsa_key_qr = request.files.get('rsa_key_qr') rsa_key_qr = request.files.get('rsa_key_qr')
rsa_key_from_qr = False rsa_key_from_qr = False
if rsa_key_file and rsa_key_file.filename: if rsa_key_file and rsa_key_file.filename:
rsa_key_data = rsa_key_file.read() rsa_key_data = rsa_key_file.read()
elif rsa_key_qr and rsa_key_qr.filename and HAS_QRCODE_READ: elif rsa_key_qr and rsa_key_qr.filename and HAS_QRCODE_READ:
@@ -1065,30 +1062,30 @@ def decode_page():
else: else:
flash('Could not extract RSA key from QR code image.', 'error') flash('Could not extract RSA key from QR code image.', 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
# Validate security factors # Validate security factors
result = validate_security_factors(pin, rsa_key_data) result = validate_security_factors(pin, rsa_key_data)
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
# Validate PIN if provided # Validate PIN if provided
if pin: if pin:
result = validate_pin(pin) result = validate_pin(pin)
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
# Determine key password # Determine key password
key_password = None if rsa_key_from_qr else (rsa_password if rsa_password else None) key_password = None if rsa_key_from_qr else (rsa_password if rsa_password else None)
# Validate RSA key if provided # Validate RSA key if provided
if rsa_key_data: if rsa_key_data:
result = validate_rsa_key(rsa_key_data, key_password) result = validate_rsa_key(rsa_key_data, key_password)
if not result.is_valid: if not result.is_valid:
flash(result.error_message, 'error') flash(result.error_message, 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
# v4.0.0: Include channel_key parameter # v4.0.0: Include channel_key parameter
# Use subprocess-isolated decode to prevent crashes # Use subprocess-isolated decode to prevent crashes
decode_result = subprocess_stego.decode( decode_result = subprocess_stego.decode(
@@ -1101,7 +1098,7 @@ def decode_page():
embed_mode=embed_mode, embed_mode=embed_mode,
channel_key=channel_key, # v4.0.0 channel_key=channel_key, # v4.0.0
) )
# Check for subprocess errors # Check for subprocess errors
if not decode_result.success: if not decode_result.success:
error_msg = decode_result.error or 'Decoding failed' error_msg = decode_result.error or 'Decoding failed'
@@ -1112,12 +1109,12 @@ def decode_page():
if 'decrypt' in error_msg.lower() or decode_result.error_type == 'DecryptionError': if 'decrypt' in error_msg.lower() or decode_result.error_type == 'DecryptionError':
raise DecryptionError(error_msg) raise DecryptionError(error_msg)
raise StegasooError(error_msg) raise StegasooError(error_msg)
if decode_result.is_file: if decode_result.is_file:
# File content - store temporarily for download # File content - store temporarily for download
file_id = secrets.token_urlsafe(16) file_id = secrets.token_urlsafe(16)
cleanup_temp_files() cleanup_temp_files()
filename = decode_result.filename or 'decoded_file' filename = decode_result.filename or 'decoded_file'
TEMP_FILES[file_id] = { TEMP_FILES[file_id] = {
'data': decode_result.file_data, 'data': decode_result.file_data,
@@ -1125,7 +1122,7 @@ def decode_page():
'mime_type': decode_result.mime_type, 'mime_type': decode_result.mime_type,
'timestamp': time.time() 'timestamp': time.time()
} }
return render_template('decode.html', return render_template('decode.html',
decoded_file=True, decoded_file=True,
file_id=file_id, file_id=file_id,
@@ -1136,11 +1133,11 @@ def decode_page():
) )
else: else:
# Text content # Text content
return render_template('decode.html', return render_template('decode.html',
decoded_message=decode_result.message, decoded_message=decode_result.message,
has_qrcode_read=HAS_QRCODE_READ has_qrcode_read=HAS_QRCODE_READ
) )
except DecryptionError: except DecryptionError:
flash('Decryption failed. Check your passphrase, PIN, RSA key, reference photo, and channel key.', 'error') flash('Decryption failed. Check your passphrase, PIN, RSA key, reference photo, and channel key.', 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
@@ -1150,7 +1147,7 @@ def decode_page():
except Exception as e: except Exception as e:
flash(f'Error: {e}', 'error') flash(f'Error: {e}', 'error')
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ) return render_template('decode.html', has_qrcode_read=HAS_QRCODE_READ)
@@ -1160,10 +1157,10 @@ def decode_download(file_id):
if file_id not in TEMP_FILES: if file_id not in TEMP_FILES:
flash('File expired or not found.', 'error') flash('File expired or not found.', 'error')
return redirect(url_for('decode_page')) return redirect(url_for('decode_page'))
file_info = TEMP_FILES[file_id] file_info = TEMP_FILES[file_id]
mime_type = file_info.get('mime_type', 'application/octet-stream') mime_type = file_info.get('mime_type', 'application/octet-stream')
return send_file( return send_file(
io.BytesIO(file_info['data']), io.BytesIO(file_info['data']),
mimetype=mime_type, mimetype=mime_type,
@@ -1174,7 +1171,7 @@ def decode_download(file_id):
@app.route('/about') @app.route('/about')
def about(): def about():
return render_template('about.html', return render_template('about.html',
has_argon2=has_argon2(), has_argon2=has_argon2(),
has_qrcode_read=HAS_QRCODE_READ has_qrcode_read=HAS_QRCODE_READ
) )
@@ -1188,7 +1185,7 @@ def test_capacity():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
if not carrier: if not carrier:
return jsonify({'error': 'No carrier image provided'}), 400 return jsonify({'error': 'No carrier image provided'}), 400
try: try:
carrier_data = carrier.read() carrier_data = carrier.read()
buffer = io.BytesIO(carrier_data) buffer = io.BytesIO(carrier_data)
@@ -1197,11 +1194,11 @@ def test_capacity():
fmt = img.format fmt = img.format
img.close() img.close()
buffer.close() buffer.close()
pixels = width * height pixels = width * height
lsb_bytes = (pixels * 3) // 8 lsb_bytes = (pixels * 3) // 8
dct_bytes = ((width // 8) * (height // 8) * 16) // 8 - 10 dct_bytes = ((width // 8) * (height // 8) * 16) // 8 - 10
return jsonify({ return jsonify({
'success': True, 'success': True,
'width': width, 'width': width,
@@ -1220,7 +1217,7 @@ def test_capacity_nopil():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
if not carrier: if not carrier:
return jsonify({'error': 'No carrier image provided'}), 400 return jsonify({'error': 'No carrier image provided'}), 400
carrier_data = carrier.read() carrier_data = carrier.read()
return jsonify({ return jsonify({
'success': True, 'success': True,

View File

@@ -17,9 +17,9 @@ Usage:
echo '{"operation": "encode", ...}' | python stego_worker.py echo '{"operation": "encode", ...}' | python stego_worker.py
""" """
import sys
import json
import base64 import base64
import json
import sys
import traceback import traceback
from pathlib import Path from pathlib import Path
@@ -31,10 +31,10 @@ sys.path.insert(0, str(Path(__file__).parent))
def _resolve_channel_key(channel_key_param): def _resolve_channel_key(channel_key_param):
""" """
Resolve channel_key parameter to value for stegasoo. Resolve channel_key parameter to value for stegasoo.
Args: Args:
channel_key_param: 'auto', 'none', explicit key, or None channel_key_param: 'auto', 'none', explicit key, or None
Returns: Returns:
None (auto), "" (public), or explicit key string None (auto), "" (public), or explicit key string
""" """
@@ -49,41 +49,41 @@ def _resolve_channel_key(channel_key_param):
def _get_channel_info(resolved_key): def _get_channel_info(resolved_key):
""" """
Get channel mode and fingerprint for response. Get channel mode and fingerprint for response.
Returns: Returns:
(mode, fingerprint) tuple (mode, fingerprint) tuple
""" """
from stegasoo import has_channel_key, get_channel_status from stegasoo import get_channel_status, has_channel_key
if resolved_key == "": if resolved_key == "":
return "public", None return "public", None
if resolved_key is not None: if resolved_key is not None:
# Explicit key # Explicit key
fingerprint = f"{resolved_key[:4]}-••••-••••-••••-••••-••••-••••-{resolved_key[-4:]}" fingerprint = f"{resolved_key[:4]}-••••-••••-••••-••••-••••-••••-{resolved_key[-4:]}"
return "private", fingerprint return "private", fingerprint
# Auto mode - check server config # Auto mode - check server config
if has_channel_key(): if has_channel_key():
status = get_channel_status() status = get_channel_status()
return "private", status.get('fingerprint') return "private", status.get('fingerprint')
return "public", None return "public", None
def encode_operation(params: dict) -> dict: def encode_operation(params: dict) -> dict:
"""Handle encode operation.""" """Handle encode operation."""
from stegasoo import encode, FilePayload from stegasoo import FilePayload, encode
# Decode base64 inputs # Decode base64 inputs
carrier_data = base64.b64decode(params['carrier_b64']) carrier_data = base64.b64decode(params['carrier_b64'])
reference_data = base64.b64decode(params['reference_b64']) reference_data = base64.b64decode(params['reference_b64'])
# Optional RSA key # Optional RSA key
rsa_key_data = None rsa_key_data = None
if params.get('rsa_key_b64'): if params.get('rsa_key_b64'):
rsa_key_data = base64.b64decode(params['rsa_key_b64']) rsa_key_data = base64.b64decode(params['rsa_key_b64'])
# Determine payload type # Determine payload type
if params.get('file_b64'): if params.get('file_b64'):
file_data = base64.b64decode(params['file_b64']) file_data = base64.b64decode(params['file_b64'])
@@ -94,10 +94,10 @@ def encode_operation(params: dict) -> dict:
) )
else: else:
payload = params.get('message', '') payload = params.get('message', '')
# Resolve channel key (v4.0.0) # Resolve channel key (v4.0.0)
resolved_channel_key = _resolve_channel_key(params.get('channel_key', 'auto')) resolved_channel_key = _resolve_channel_key(params.get('channel_key', 'auto'))
# Call encode with correct parameter names # Call encode with correct parameter names
result = encode( result = encode(
message=payload, message=payload,
@@ -112,7 +112,7 @@ def encode_operation(params: dict) -> dict:
dct_color_mode=params.get('dct_color_mode', 'color'), dct_color_mode=params.get('dct_color_mode', 'color'),
channel_key=resolved_channel_key, # v4.0.0 channel_key=resolved_channel_key, # v4.0.0
) )
# Build stats dict if available # Build stats dict if available
stats = None stats = None
if hasattr(result, 'stats') and result.stats: if hasattr(result, 'stats') and result.stats:
@@ -121,10 +121,10 @@ def encode_operation(params: dict) -> dict:
'capacity_used': getattr(result.stats, 'capacity_used', 0), 'capacity_used': getattr(result.stats, 'capacity_used', 0),
'bytes_embedded': getattr(result.stats, 'bytes_embedded', 0), 'bytes_embedded': getattr(result.stats, 'bytes_embedded', 0),
} }
# Get channel info for response (v4.0.0) # Get channel info for response (v4.0.0)
channel_mode, channel_fingerprint = _get_channel_info(resolved_channel_key) channel_mode, channel_fingerprint = _get_channel_info(resolved_channel_key)
return { return {
'success': True, 'success': True,
'stego_b64': base64.b64encode(result.stego_image).decode('ascii'), 'stego_b64': base64.b64encode(result.stego_image).decode('ascii'),
@@ -138,19 +138,19 @@ def encode_operation(params: dict) -> dict:
def decode_operation(params: dict) -> dict: def decode_operation(params: dict) -> dict:
"""Handle decode operation.""" """Handle decode operation."""
from stegasoo import decode from stegasoo import decode
# Decode base64 inputs # Decode base64 inputs
stego_data = base64.b64decode(params['stego_b64']) stego_data = base64.b64decode(params['stego_b64'])
reference_data = base64.b64decode(params['reference_b64']) reference_data = base64.b64decode(params['reference_b64'])
# Optional RSA key # Optional RSA key
rsa_key_data = None rsa_key_data = None
if params.get('rsa_key_b64'): if params.get('rsa_key_b64'):
rsa_key_data = base64.b64decode(params['rsa_key_b64']) rsa_key_data = base64.b64decode(params['rsa_key_b64'])
# Resolve channel key (v4.0.0) # Resolve channel key (v4.0.0)
resolved_channel_key = _resolve_channel_key(params.get('channel_key', 'auto')) resolved_channel_key = _resolve_channel_key(params.get('channel_key', 'auto'))
# Call decode with correct parameter names # Call decode with correct parameter names
result = decode( result = decode(
stego_image=stego_data, stego_image=stego_data,
@@ -162,7 +162,7 @@ def decode_operation(params: dict) -> dict:
embed_mode=params.get('embed_mode', 'auto'), embed_mode=params.get('embed_mode', 'auto'),
channel_key=resolved_channel_key, # v4.0.0 channel_key=resolved_channel_key, # v4.0.0
) )
if result.is_file: if result.is_file:
return { return {
'success': True, 'success': True,
@@ -182,10 +182,10 @@ def decode_operation(params: dict) -> dict:
def compare_operation(params: dict) -> dict: def compare_operation(params: dict) -> dict:
"""Handle compare_modes operation.""" """Handle compare_modes operation."""
from stegasoo import compare_modes from stegasoo import compare_modes
carrier_data = base64.b64decode(params['carrier_b64']) carrier_data = base64.b64decode(params['carrier_b64'])
result = compare_modes(carrier_data) result = compare_modes(carrier_data)
return { return {
'success': True, 'success': True,
'comparison': result, 'comparison': result,
@@ -195,15 +195,15 @@ def compare_operation(params: dict) -> dict:
def capacity_check_operation(params: dict) -> dict: def capacity_check_operation(params: dict) -> dict:
"""Handle will_fit_by_mode operation.""" """Handle will_fit_by_mode operation."""
from stegasoo import will_fit_by_mode from stegasoo import will_fit_by_mode
carrier_data = base64.b64decode(params['carrier_b64']) carrier_data = base64.b64decode(params['carrier_b64'])
result = will_fit_by_mode( result = will_fit_by_mode(
payload=params['payload_size'], payload=params['payload_size'],
carrier_image=carrier_data, carrier_image=carrier_data,
embed_mode=params.get('embed_mode', 'lsb'), embed_mode=params.get('embed_mode', 'lsb'),
) )
return { return {
'success': True, 'success': True,
'result': result, 'result': result,
@@ -213,10 +213,10 @@ def capacity_check_operation(params: dict) -> dict:
def channel_status_operation(params: dict) -> dict: def channel_status_operation(params: dict) -> dict:
"""Handle channel status check (v4.0.0).""" """Handle channel status check (v4.0.0)."""
from stegasoo import get_channel_status from stegasoo import get_channel_status
status = get_channel_status() status = get_channel_status()
reveal = params.get('reveal', False) reveal = params.get('reveal', False)
return { return {
'success': True, 'success': True,
'status': { 'status': {
@@ -234,13 +234,13 @@ def main():
try: try:
# Read all input # Read all input
input_text = sys.stdin.read() input_text = sys.stdin.read()
if not input_text.strip(): if not input_text.strip():
output = {'success': False, 'error': 'No input provided'} output = {'success': False, 'error': 'No input provided'}
else: else:
params = json.loads(input_text) params = json.loads(input_text)
operation = params.get('operation') operation = params.get('operation')
if operation == 'encode': if operation == 'encode':
output = encode_operation(params) output = encode_operation(params)
elif operation == 'decode': elif operation == 'decode':
@@ -253,7 +253,7 @@ def main():
output = channel_status_operation(params) output = channel_status_operation(params)
else: else:
output = {'success': False, 'error': f'Unknown operation: {operation}'} output = {'success': False, 'error': f'Unknown operation: {operation}'}
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
output = {'success': False, 'error': f'Invalid JSON: {e}'} output = {'success': False, 'error': f'Invalid JSON: {e}'}
except Exception as e: except Exception as e:
@@ -263,7 +263,7 @@ def main():
'error_type': type(e).__name__, 'error_type': type(e).__name__,
'traceback': traceback.format_exc(), 'traceback': traceback.format_exc(),
} }
# Write output as JSON # Write output as JSON
print(json.dumps(output), flush=True) print(json.dumps(output), flush=True)

View File

@@ -10,9 +10,9 @@ CHANGES in v4.0.0:
Usage: Usage:
from subprocess_stego import SubprocessStego from subprocess_stego import SubprocessStego
stego = SubprocessStego() stego = SubprocessStego()
# Encode with channel key # Encode with channel key
result = stego.encode( result = stego.encode(
carrier_data=carrier_bytes, carrier_data=carrier_bytes,
@@ -23,13 +23,13 @@ Usage:
embed_mode="dct", embed_mode="dct",
channel_key="auto", # or "none", or explicit key channel_key="auto", # or "none", or explicit key
) )
if result.success: if result.success:
stego_bytes = result.stego_data stego_bytes = result.stego_data
extension = result.extension extension = result.extension
else: else:
error_message = result.error error_message = result.error
# Decode # Decode
result = stego.decode( result = stego.decode(
stego_data=stego_bytes, stego_data=stego_bytes,
@@ -38,19 +38,18 @@ Usage:
pin="123456", pin="123456",
channel_key="auto", channel_key="auto",
) )
# Compare modes (capacity) # Compare modes (capacity)
result = stego.compare_modes(carrier_bytes) result = stego.compare_modes(carrier_bytes)
""" """
import json
import base64 import base64
import json
import subprocess import subprocess
import sys import sys
from pathlib import Path
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Dict, Any, Union from pathlib import Path
from typing import Any
# Default timeout for operations (seconds) # Default timeout for operations (seconds)
DEFAULT_TIMEOUT = 120 DEFAULT_TIMEOUT = 120
@@ -63,14 +62,14 @@ WORKER_SCRIPT = Path(__file__).parent / 'stego_worker.py'
class EncodeResult: class EncodeResult:
"""Result from encode operation.""" """Result from encode operation."""
success: bool success: bool
stego_data: Optional[bytes] = None stego_data: bytes | None = None
filename: Optional[str] = None filename: str | None = None
stats: Optional[Dict[str, Any]] = None stats: dict[str, Any] | None = None
# Channel info (v4.0.0) # Channel info (v4.0.0)
channel_mode: Optional[str] = None channel_mode: str | None = None
channel_fingerprint: Optional[str] = None channel_fingerprint: str | None = None
error: Optional[str] = None error: str | None = None
error_type: Optional[str] = None error_type: str | None = None
@dataclass @dataclass
@@ -78,12 +77,12 @@ class DecodeResult:
"""Result from decode operation.""" """Result from decode operation."""
success: bool success: bool
is_file: bool = False is_file: bool = False
message: Optional[str] = None message: str | None = None
file_data: Optional[bytes] = None file_data: bytes | None = None
filename: Optional[str] = None filename: str | None = None
mime_type: Optional[str] = None mime_type: str | None = None
error: Optional[str] = None error: str | None = None
error_type: Optional[str] = None error_type: str | None = None
@dataclass @dataclass
@@ -92,9 +91,9 @@ class CompareResult:
success: bool success: bool
width: int = 0 width: int = 0
height: int = 0 height: int = 0
lsb: Optional[Dict[str, Any]] = None lsb: dict[str, Any] | None = None
dct: Optional[Dict[str, Any]] = None dct: dict[str, Any] | None = None
error: Optional[str] = None error: str | None = None
@dataclass @dataclass
@@ -107,38 +106,38 @@ class CapacityResult:
usage_percent: float = 0.0 usage_percent: float = 0.0
headroom: int = 0 headroom: int = 0
mode: str = "" mode: str = ""
error: Optional[str] = None error: str | None = None
@dataclass @dataclass
class ChannelStatusResult: class ChannelStatusResult:
"""Result from channel status check (v4.0.0).""" """Result from channel status check (v4.0.0)."""
success: bool success: bool
mode: str = "public" mode: str = "public"
configured: bool = False configured: bool = False
fingerprint: Optional[str] = None fingerprint: str | None = None
source: Optional[str] = None source: str | None = None
key: Optional[str] = None key: str | None = None
error: Optional[str] = None error: str | None = None
class SubprocessStego: class SubprocessStego:
""" """
Subprocess-isolated steganography operations. Subprocess-isolated steganography operations.
All operations run in a separate Python process. If jpegio or scipy All operations run in a separate Python process. If jpegio or scipy
crashes, only the subprocess dies - Flask keeps running. crashes, only the subprocess dies - Flask keeps running.
""" """
def __init__( def __init__(
self, self,
worker_path: Optional[Path] = None, worker_path: Path | None = None,
python_executable: Optional[str] = None, python_executable: str | None = None,
timeout: int = DEFAULT_TIMEOUT, timeout: int = DEFAULT_TIMEOUT,
): ):
""" """
Initialize subprocess wrapper. Initialize subprocess wrapper.
Args: Args:
worker_path: Path to stego_worker.py (default: same directory) worker_path: Path to stego_worker.py (default: same directory)
python_executable: Python interpreter to use (default: same as current) python_executable: Python interpreter to use (default: same as current)
@@ -147,24 +146,24 @@ class SubprocessStego:
self.worker_path = worker_path or WORKER_SCRIPT self.worker_path = worker_path or WORKER_SCRIPT
self.python = python_executable or sys.executable self.python = python_executable or sys.executable
self.timeout = timeout self.timeout = timeout
if not self.worker_path.exists(): if not self.worker_path.exists():
raise FileNotFoundError(f"Worker script not found: {self.worker_path}") raise FileNotFoundError(f"Worker script not found: {self.worker_path}")
def _run_worker(self, params: Dict[str, Any], timeout: Optional[int] = None) -> Dict[str, Any]: def _run_worker(self, params: dict[str, Any], timeout: int | None = None) -> dict[str, Any]:
""" """
Run the worker subprocess with given parameters. Run the worker subprocess with given parameters.
Args: Args:
params: Dictionary of parameters (will be JSON-encoded) params: Dictionary of parameters (will be JSON-encoded)
timeout: Operation timeout in seconds timeout: Operation timeout in seconds
Returns: Returns:
Dictionary with results from worker Dictionary with results from worker
""" """
timeout = timeout or self.timeout timeout = timeout or self.timeout
input_json = json.dumps(params) input_json = json.dumps(params)
try: try:
result = subprocess.run( result = subprocess.run(
[self.python, str(self.worker_path)], [self.python, str(self.worker_path)],
@@ -174,7 +173,7 @@ class SubprocessStego:
timeout=timeout, timeout=timeout,
cwd=str(self.worker_path.parent), cwd=str(self.worker_path.parent),
) )
if result.returncode != 0: if result.returncode != 0:
# Worker crashed # Worker crashed
return { return {
@@ -182,16 +181,16 @@ class SubprocessStego:
'error': f'Worker crashed (exit code {result.returncode})', 'error': f'Worker crashed (exit code {result.returncode})',
'stderr': result.stderr, 'stderr': result.stderr,
} }
if not result.stdout.strip(): if not result.stdout.strip():
return { return {
'success': False, 'success': False,
'error': 'Worker returned empty output', 'error': 'Worker returned empty output',
'stderr': result.stderr, 'stderr': result.stderr,
} }
return json.loads(result.stdout) return json.loads(result.stdout)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
return { return {
'success': False, 'success': False,
@@ -210,29 +209,29 @@ class SubprocessStego:
'error': str(e), 'error': str(e),
'error_type': type(e).__name__, 'error_type': type(e).__name__,
} }
def encode( def encode(
self, self,
carrier_data: bytes, carrier_data: bytes,
reference_data: bytes, reference_data: bytes,
message: Optional[str] = None, message: str | None = None,
file_data: Optional[bytes] = None, file_data: bytes | None = None,
file_name: Optional[str] = None, file_name: str | None = None,
file_mime: Optional[str] = None, file_mime: str | None = None,
passphrase: str = "", passphrase: str = "",
pin: Optional[str] = None, pin: str | None = None,
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
embed_mode: str = "lsb", embed_mode: str = "lsb",
dct_output_format: str = "png", dct_output_format: str = "png",
dct_color_mode: str = "color", dct_color_mode: str = "color",
# Channel key (v4.0.0) # Channel key (v4.0.0)
channel_key: Optional[str] = "auto", channel_key: str | None = "auto",
timeout: Optional[int] = None, timeout: int | None = None,
) -> EncodeResult: ) -> EncodeResult:
""" """
Encode a message or file into an image. Encode a message or file into an image.
Args: Args:
carrier_data: Carrier image bytes carrier_data: Carrier image bytes
reference_data: Reference photo bytes reference_data: Reference photo bytes
@@ -249,7 +248,7 @@ class SubprocessStego:
dct_color_mode: 'grayscale' or 'color' (for DCT mode) dct_color_mode: 'grayscale' or 'color' (for DCT mode)
channel_key: 'auto' (server config), 'none' (public), or explicit key (v4.0.0) channel_key: 'auto' (server config), 'none' (public), or explicit key (v4.0.0)
timeout: Operation timeout in seconds timeout: Operation timeout in seconds
Returns: Returns:
EncodeResult with stego_data and extension on success EncodeResult with stego_data and extension on success
""" """
@@ -265,18 +264,18 @@ class SubprocessStego:
'dct_color_mode': dct_color_mode, 'dct_color_mode': dct_color_mode,
'channel_key': channel_key, # v4.0.0 'channel_key': channel_key, # v4.0.0
} }
if file_data: if file_data:
params['file_b64'] = base64.b64encode(file_data).decode('ascii') params['file_b64'] = base64.b64encode(file_data).decode('ascii')
params['file_name'] = file_name params['file_name'] = file_name
params['file_mime'] = file_mime params['file_mime'] = file_mime
if rsa_key_data: if rsa_key_data:
params['rsa_key_b64'] = base64.b64encode(rsa_key_data).decode('ascii') params['rsa_key_b64'] = base64.b64encode(rsa_key_data).decode('ascii')
params['rsa_password'] = rsa_password params['rsa_password'] = rsa_password
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get('success'):
return EncodeResult( return EncodeResult(
success=True, success=True,
@@ -292,23 +291,23 @@ class SubprocessStego:
error=result.get('error', 'Unknown error'), error=result.get('error', 'Unknown error'),
error_type=result.get('error_type'), error_type=result.get('error_type'),
) )
def decode( def decode(
self, self,
stego_data: bytes, stego_data: bytes,
reference_data: bytes, reference_data: bytes,
passphrase: str = "", passphrase: str = "",
pin: Optional[str] = None, pin: str | None = None,
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
embed_mode: str = "auto", embed_mode: str = "auto",
# Channel key (v4.0.0) # Channel key (v4.0.0)
channel_key: Optional[str] = "auto", channel_key: str | None = "auto",
timeout: Optional[int] = None, timeout: int | None = None,
) -> DecodeResult: ) -> DecodeResult:
""" """
Decode a message or file from a stego image. Decode a message or file from a stego image.
Args: Args:
stego_data: Stego image bytes stego_data: Stego image bytes
reference_data: Reference photo bytes reference_data: Reference photo bytes
@@ -319,7 +318,7 @@ class SubprocessStego:
embed_mode: 'auto', 'lsb', or 'dct' embed_mode: 'auto', 'lsb', or 'dct'
channel_key: 'auto' (server config), 'none' (public), or explicit key (v4.0.0) channel_key: 'auto' (server config), 'none' (public), or explicit key (v4.0.0)
timeout: Operation timeout in seconds timeout: Operation timeout in seconds
Returns: Returns:
DecodeResult with message or file_data on success DecodeResult with message or file_data on success
""" """
@@ -332,13 +331,13 @@ class SubprocessStego:
'embed_mode': embed_mode, 'embed_mode': embed_mode,
'channel_key': channel_key, # v4.0.0 'channel_key': channel_key, # v4.0.0
} }
if rsa_key_data: if rsa_key_data:
params['rsa_key_b64'] = base64.b64encode(rsa_key_data).decode('ascii') params['rsa_key_b64'] = base64.b64encode(rsa_key_data).decode('ascii')
params['rsa_password'] = rsa_password params['rsa_password'] = rsa_password
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get('success'):
if result.get('is_file'): if result.get('is_file'):
return DecodeResult( return DecodeResult(
@@ -360,19 +359,19 @@ class SubprocessStego:
error=result.get('error', 'Unknown error'), error=result.get('error', 'Unknown error'),
error_type=result.get('error_type'), error_type=result.get('error_type'),
) )
def compare_modes( def compare_modes(
self, self,
carrier_data: bytes, carrier_data: bytes,
timeout: Optional[int] = None, timeout: int | None = None,
) -> CompareResult: ) -> CompareResult:
""" """
Compare LSB and DCT capacity for a carrier image. Compare LSB and DCT capacity for a carrier image.
Args: Args:
carrier_data: Carrier image bytes carrier_data: Carrier image bytes
timeout: Operation timeout in seconds timeout: Operation timeout in seconds
Returns: Returns:
CompareResult with capacity information CompareResult with capacity information
""" """
@@ -380,9 +379,9 @@ class SubprocessStego:
'operation': 'compare', 'operation': 'compare',
'carrier_b64': base64.b64encode(carrier_data).decode('ascii'), 'carrier_b64': base64.b64encode(carrier_data).decode('ascii'),
} }
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get('success'):
comparison = result.get('comparison', {}) comparison = result.get('comparison', {})
return CompareResult( return CompareResult(
@@ -397,23 +396,23 @@ class SubprocessStego:
success=False, success=False,
error=result.get('error', 'Unknown error'), error=result.get('error', 'Unknown error'),
) )
def check_capacity( def check_capacity(
self, self,
carrier_data: bytes, carrier_data: bytes,
payload_size: int, payload_size: int,
embed_mode: str = "lsb", embed_mode: str = "lsb",
timeout: Optional[int] = None, timeout: int | None = None,
) -> CapacityResult: ) -> CapacityResult:
""" """
Check if a payload will fit in the carrier. Check if a payload will fit in the carrier.
Args: Args:
carrier_data: Carrier image bytes carrier_data: Carrier image bytes
payload_size: Size of payload in bytes payload_size: Size of payload in bytes
embed_mode: 'lsb' or 'dct' embed_mode: 'lsb' or 'dct'
timeout: Operation timeout in seconds timeout: Operation timeout in seconds
Returns: Returns:
CapacityResult with fit information CapacityResult with fit information
""" """
@@ -423,9 +422,9 @@ class SubprocessStego:
'payload_size': payload_size, 'payload_size': payload_size,
'embed_mode': embed_mode, 'embed_mode': embed_mode,
} }
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get('success'):
r = result.get('result', {}) r = result.get('result', {})
return CapacityResult( return CapacityResult(
@@ -442,19 +441,19 @@ class SubprocessStego:
success=False, success=False,
error=result.get('error', 'Unknown error'), error=result.get('error', 'Unknown error'),
) )
def get_channel_status( def get_channel_status(
self, self,
reveal: bool = False, reveal: bool = False,
timeout: Optional[int] = None, timeout: int | None = None,
) -> ChannelStatusResult: ) -> ChannelStatusResult:
""" """
Get current channel key status (v4.0.0). Get current channel key status (v4.0.0).
Args: Args:
reveal: Include full key in response reveal: Include full key in response
timeout: Operation timeout in seconds timeout: Operation timeout in seconds
Returns: Returns:
ChannelStatusResult with channel info ChannelStatusResult with channel info
""" """
@@ -462,9 +461,9 @@ class SubprocessStego:
'operation': 'channel_status', 'operation': 'channel_status',
'reveal': reveal, 'reveal': reveal,
} }
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get('success'):
status = result.get('status', {}) status = result.get('status', {})
return ChannelStatusResult( return ChannelStatusResult(
@@ -483,7 +482,7 @@ class SubprocessStego:
# Convenience function for quick usage # Convenience function for quick usage
_default_stego: Optional[SubprocessStego] = None _default_stego: SubprocessStego | None = None
def get_subprocess_stego() -> SubprocessStego: def get_subprocess_stego() -> SubprocessStego:

View File

@@ -1,90 +0,0 @@
"""
Minimal test to isolate the memory corruption crash.
Add this route to your app.py temporarily to test if the crash
is in Flask/Pillow or in stegasoo code.
Usage:
1. Add this code to app.py
2. Restart the server
3. Use the /test-capacity endpoint instead of /api/compare-capacity
4. If it crashes: Flask or Pillow issue
5. If it works: Stegasoo code issue
"""
# Add these imports at the top of app.py if not present:
# from PIL import Image
# import io
# Add this route to app.py:
@app.route('/test-capacity', methods=['POST'])
def test_capacity():
"""
Minimal capacity test - no stegasoo code, just PIL.
"""
carrier = request.files.get('carrier')
if not carrier:
return jsonify({'error': 'No carrier image provided'}), 400
try:
# Read the file data
carrier_data = carrier.read()
# Method 1: Just get size from PIL
buffer = io.BytesIO(carrier_data)
img = Image.open(buffer)
width, height = img.size
fmt = img.format
mode = img.mode
img.close()
buffer.close()
# Simple capacity calculation (no scipy, no numpy)
pixels = width * height
lsb_bytes = (pixels * 3) // 8
blocks = (width // 8) * (height // 8)
dct_bytes = (blocks * 16) // 8 - 10
return jsonify({
'success': True,
'width': width,
'height': height,
'format': fmt,
'mode': mode,
'lsb': {
'capacity_bytes': lsb_bytes,
'capacity_kb': round(lsb_bytes / 1024, 1),
},
'dct': {
'capacity_bytes': dct_bytes,
'capacity_kb': round(dct_bytes / 1024, 1),
}
})
except Exception as e:
import traceback
return jsonify({'error': str(e), 'trace': traceback.format_exc()}), 500
# Alternative: completely bypass PIL too
@app.route('/test-capacity-nopil', methods=['POST'])
def test_capacity_nopil():
"""
Ultra-minimal test - no PIL, no stegasoo.
"""
carrier = request.files.get('carrier')
if not carrier:
return jsonify({'error': 'No carrier image provided'}), 400
try:
carrier_data = carrier.read()
# Just return size info, no image processing at all
return jsonify({
'success': True,
'data_size': len(carrier_data),
'first_bytes': carrier_data[:20].hex() if len(carrier_data) >= 20 else carrier_data.hex(),
})
except Exception as e:
import traceback
return jsonify({'error': str(e), 'trace': traceback.format_exc()}), 500

View File

@@ -39,19 +39,19 @@ def test1_pil_only():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
if not carrier: if not carrier:
return jsonify({'error': 'No carrier'}), 400 return jsonify({'error': 'No carrier'}), 400
data = carrier.read() data = carrier.read()
print(f"[test1] Read {len(data)} bytes") print(f"[test1] Read {len(data)} bytes")
img = Image.open(io.BytesIO(data)) img = Image.open(io.BytesIO(data))
width, height = img.size width, height = img.size
fmt = img.format fmt = img.format
img.close() img.close()
print(f"[test1] Image: {width}x{height} {fmt}") print(f"[test1] Image: {width}x{height} {fmt}")
gc.collect() gc.collect()
print("[test1] Returning response...") print("[test1] Returning response...")
return jsonify({ return jsonify({
'test': 'pil_only', 'test': 'pil_only',
'width': width, 'width': width,
@@ -66,31 +66,31 @@ def test2_multiple_opens():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
if not carrier: if not carrier:
return jsonify({'error': 'No carrier'}), 400 return jsonify({'error': 'No carrier'}), 400
data = carrier.read() data = carrier.read()
print(f"[test2] Read {len(data)} bytes") print(f"[test2] Read {len(data)} bytes")
# First open # First open
img1 = Image.open(io.BytesIO(data)) img1 = Image.open(io.BytesIO(data))
width, height = img1.size width, height = img1.size
img1.close() img1.close()
print(f"[test2] Open 1: {width}x{height}") print(f"[test2] Open 1: {width}x{height}")
# Second open # Second open
img2 = Image.open(io.BytesIO(data)) img2 = Image.open(io.BytesIO(data))
pixels = img2.size[0] * img2.size[1] pixels = img2.size[0] * img2.size[1]
img2.close() img2.close()
print(f"[test2] Open 2: {pixels} pixels") print(f"[test2] Open 2: {pixels} pixels")
# Third open # Third open
img3 = Image.open(io.BytesIO(data)) img3 = Image.open(io.BytesIO(data))
blocks = (img3.size[0] // 8) * (img3.size[1] // 8) blocks = (img3.size[0] // 8) * (img3.size[1] // 8)
img3.close() img3.close()
print(f"[test2] Open 3: {blocks} blocks") print(f"[test2] Open 3: {blocks} blocks")
gc.collect() gc.collect()
print("[test2] Returning response...") print("[test2] Returning response...")
return jsonify({ return jsonify({
'test': 'multiple_opens', 'test': 'multiple_opens',
'width': width, 'width': width,
@@ -105,39 +105,39 @@ def test3_with_jpegio():
"""Test 3: Include jpegio operations""" """Test 3: Include jpegio operations"""
if not HAS_JPEGIO: if not HAS_JPEGIO:
return jsonify({'error': 'jpegio not available'}), 501 return jsonify({'error': 'jpegio not available'}), 501
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
if not carrier: if not carrier:
return jsonify({'error': 'No carrier'}), 400 return jsonify({'error': 'No carrier'}), 400
data = carrier.read() data = carrier.read()
print(f"[test3] Read {len(data)} bytes") print(f"[test3] Read {len(data)} bytes")
# Check if JPEG # Check if JPEG
img = Image.open(io.BytesIO(data)) img = Image.open(io.BytesIO(data))
is_jpeg = img.format == 'JPEG' is_jpeg = img.format == 'JPEG'
width, height = img.size width, height = img.size
img.close() img.close()
print(f"[test3] Image: {width}x{height}, JPEG: {is_jpeg}") print(f"[test3] Image: {width}x{height}, JPEG: {is_jpeg}")
if not is_jpeg: if not is_jpeg:
return jsonify({'error': 'Not a JPEG'}), 400 return jsonify({'error': 'Not a JPEG'}), 400
# Write to temp file # Write to temp file
fd, temp_path = tempfile.mkstemp(suffix='.jpg') fd, temp_path = tempfile.mkstemp(suffix='.jpg')
os.write(fd, data) os.write(fd, data)
os.close(fd) os.close(fd)
print(f"[test3] Temp file: {temp_path}") print(f"[test3] Temp file: {temp_path}")
try: try:
# Read with jpegio # Read with jpegio
jpeg = jio.read(temp_path) jpeg = jio.read(temp_path)
print(f"[test3] jpegio.read() OK") print(f"[test3] jpegio.read() OK")
coef = jpeg.coef_arrays[0] coef = jpeg.coef_arrays[0]
coef_shape = coef.shape coef_shape = coef.shape
print(f"[test3] Coef shape: {coef_shape}") print(f"[test3] Coef shape: {coef_shape}")
# Count positions like the real code does # Count positions like the real code does
positions = 0 positions = 0
h, w = coef.shape h, w = coef.shape
@@ -148,19 +148,19 @@ def test3_with_jpegio():
if abs(coef[row, col]) >= 2: if abs(coef[row, col]) >= 2:
positions += 1 positions += 1
print(f"[test3] Usable positions: {positions}") print(f"[test3] Usable positions: {positions}")
# Cleanup # Cleanup
del coef del coef
del jpeg del jpeg
print(f"[test3] Deleted jpegio objects") print(f"[test3] Deleted jpegio objects")
finally: finally:
os.unlink(temp_path) os.unlink(temp_path)
print(f"[test3] Removed temp file") print(f"[test3] Removed temp file")
gc.collect() gc.collect()
print("[test3] Returning response...") print("[test3] Returning response...")
return jsonify({ return jsonify({
'test': 'with_jpegio', 'test': 'with_jpegio',
'width': width, 'width': width,
@@ -176,34 +176,34 @@ def test4_numpy_array_from_pil():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
if not carrier: if not carrier:
return jsonify({'error': 'No carrier'}), 400 return jsonify({'error': 'No carrier'}), 400
data = carrier.read() data = carrier.read()
print(f"[test4] Read {len(data)} bytes") print(f"[test4] Read {len(data)} bytes")
img = Image.open(io.BytesIO(data)) img = Image.open(io.BytesIO(data))
width, height = img.size width, height = img.size
print(f"[test4] Image: {width}x{height}") print(f"[test4] Image: {width}x{height}")
# Convert to grayscale and numpy array # Convert to grayscale and numpy array
gray = img.convert('L') gray = img.convert('L')
arr = np.array(gray, dtype=np.float64, copy=True) arr = np.array(gray, dtype=np.float64, copy=True)
print(f"[test4] Array: {arr.shape} {arr.dtype}") print(f"[test4] Array: {arr.shape} {arr.dtype}")
# Close PIL images # Close PIL images
gray.close() gray.close()
img.close() img.close()
print(f"[test4] PIL closed") print(f"[test4] PIL closed")
# Do some numpy operations # Do some numpy operations
mean_val = float(np.mean(arr)) mean_val = float(np.mean(arr))
std_val = float(np.std(arr)) std_val = float(np.std(arr))
print(f"[test4] Stats: mean={mean_val:.2f}, std={std_val:.2f}") print(f"[test4] Stats: mean={mean_val:.2f}, std={std_val:.2f}")
# Clear array # Clear array
del arr del arr
gc.collect() gc.collect()
print("[test4] Returning response...") print("[test4] Returning response...")
return jsonify({ return jsonify({
'test': 'numpy_from_pil', 'test': 'numpy_from_pil',
'width': width, 'width': width,
@@ -219,32 +219,32 @@ def test5_file_read_keep_reference():
carrier = request.files.get('carrier') carrier = request.files.get('carrier')
if not carrier: if not carrier:
return jsonify({'error': 'No carrier'}), 400 return jsonify({'error': 'No carrier'}), 400
# Don't read into local variable - read directly each time # Don't read into local variable - read directly each time
# This mimics potential issues with Flask's file handling # This mimics potential issues with Flask's file handling
print(f"[test5] File object: {carrier}") print(f"[test5] File object: {carrier}")
# Read once # Read once
carrier.seek(0) carrier.seek(0)
data1 = carrier.read() data1 = carrier.read()
print(f"[test5] First read: {len(data1)} bytes") print(f"[test5] First read: {len(data1)} bytes")
img = Image.open(io.BytesIO(data1)) img = Image.open(io.BytesIO(data1))
width, height = img.size width, height = img.size
img.close() img.close()
# Try to read again (should be empty or need seek) # Try to read again (should be empty or need seek)
data2 = carrier.read() data2 = carrier.read()
print(f"[test5] Second read (no seek): {len(data2)} bytes") print(f"[test5] Second read (no seek): {len(data2)} bytes")
carrier.seek(0) carrier.seek(0)
data3 = carrier.read() data3 = carrier.read()
print(f"[test5] Third read (after seek): {len(data3)} bytes") print(f"[test5] Third read (after seek): {len(data3)} bytes")
gc.collect() gc.collect()
print("[test5] Returning response...") print("[test5] Returning response...")
return jsonify({ return jsonify({
'test': 'file_handling', 'test': 'file_handling',
'width': width, 'width': width,
@@ -285,5 +285,5 @@ if __name__ == '__main__':
print("\nUsage:") print("\nUsage:")
print(' curl -X POST -F "carrier=@xx_2.jpg" http://localhost:5001/test1') print(' curl -X POST -F "carrier=@xx_2.jpg" http://localhost:5001/test1')
print("=" * 60 + "\n") print("=" * 60 + "\n")
app.run(host='0.0.0.0', port=5001, debug=False, threaded=False) app.run(host='0.0.0.0', port=5001, debug=False, threaded=False)

View File

@@ -114,9 +114,18 @@ target-version = ["py310", "py311", "py312"]
[tool.ruff] [tool.ruff]
line-length = 100 line-length = 100
exclude = ["frontends/web/test_routes.py"] # Debug snippet, not a real module
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"] select = ["E", "F", "I", "N", "W", "UP"]
ignore = ["E501"] ignore = ["E501"]
[tool.ruff.lint.per-file-ignores]
# YCbCr colorspace variables (R, G, B, Y, Cb, Cr) are standard names
"src/stegasoo/dct_steganography.py" = ["N803", "N806"]
# Package __init__.py has imports after try/except and aliases - intentional structure
"src/stegasoo/__init__.py" = ["E402"]
[tool.mypy] [tool.mypy]
python_version = "3.10" python_version = "3.10"
warn_return_any = true warn_return_any = true

View File

@@ -17,7 +17,7 @@ import sys
def main(): def main():
""" """
Main entry point for Stegasoo CLI. Main entry point for Stegasoo CLI.
Delegates to the CLI module for command parsing and execution. Delegates to the CLI module for command parsing and execution.
""" """
try: try:

View File

@@ -10,56 +10,55 @@ Changes in v4.0.0:
__version__ = "4.0.1" __version__ = "4.0.1"
# Core functionality # Core functionality
from .encode import encode # Channel key management (v4.0.0)
from .channel import (
clear_channel_key,
format_channel_key,
generate_channel_key,
get_channel_key,
get_channel_status,
has_channel_key,
set_channel_key,
validate_channel_key,
)
# Crypto functions
from .crypto import get_active_channel_key, get_channel_fingerprint, has_argon2
from .decode import decode, decode_file, decode_text from .decode import decode, decode_file, decode_text
from .encode import encode
# Credential generation # Credential generation
from .generate import ( from .generate import (
generate_pin,
generate_passphrase,
generate_rsa_key,
generate_credentials,
export_rsa_key_pem, export_rsa_key_pem,
generate_credentials,
generate_passphrase,
generate_pin,
generate_rsa_key,
load_rsa_key, load_rsa_key,
) )
# Image utilities # Image utilities
from .image_utils import ( from .image_utils import (
get_image_info,
compare_capacity, compare_capacity,
get_image_info,
)
# Steganography functions
from .steganography import (
compare_modes,
has_dct_support,
will_fit_by_mode,
) )
# Utilities # Utilities
from .utils import generate_filename from .utils import generate_filename
# Crypto functions
from .crypto import has_argon2, get_active_channel_key, get_channel_fingerprint
# Channel key management (v4.0.0)
from .channel import (
generate_channel_key,
get_channel_key,
set_channel_key,
clear_channel_key,
has_channel_key,
get_channel_status,
validate_channel_key,
format_channel_key,
)
# Steganography functions
from .steganography import (
has_dct_support,
compare_modes,
will_fit_by_mode,
)
# QR Code utilities - optional, may not be available # QR Code utilities - optional, may not be available
try: try:
from .qr_utils import ( from .qr_utils import (
generate_qr_code,
extract_key_from_qr,
detect_and_crop_qr, detect_and_crop_qr,
extract_key_from_qr,
generate_qr_code,
) )
HAS_QR_UTILS = True HAS_QR_UTILS = True
except ImportError: except ImportError:
@@ -70,12 +69,12 @@ except ImportError:
# Validation # Validation
from .validation import ( from .validation import (
validate_file_payload,
validate_image,
validate_message,
validate_passphrase, validate_passphrase,
validate_pin, validate_pin,
validate_rsa_key, validate_rsa_key,
validate_message,
validate_file_payload,
validate_image,
validate_security_factors, validate_security_factors,
) )
@@ -84,62 +83,61 @@ validate_reference_photo = validate_image
validate_carrier = validate_image validate_carrier = validate_image
# Additional validators # Additional validators
from .validation import ( # Constants
validate_embed_mode, from .constants import (
validate_dct_output_format, DEFAULT_PASSPHRASE_WORDS,
validate_dct_color_mode, EMBED_MODE_AUTO,
) EMBED_MODE_DCT,
EMBED_MODE_LSB,
# Models FORMAT_VERSION,
from .models import ( LOSSLESS_FORMATS,
ImageInfo, MAX_IMAGE_PIXELS,
CapacityComparison, MAX_MESSAGE_SIZE,
GenerateResult, MAX_PASSPHRASE_WORDS,
EncodeResult, MAX_PIN_LENGTH,
DecodeResult, MIN_IMAGE_PIXELS,
FilePayload, MIN_PASSPHRASE_WORDS,
Credentials, MIN_PIN_LENGTH,
ValidationResult, RECOMMENDED_PASSPHRASE_WORDS,
) )
# Exceptions # Exceptions
from .exceptions import ( from .exceptions import (
StegasooError, CapacityError,
ValidationError,
PinValidationError,
MessageValidationError,
ImageValidationError,
KeyValidationError,
SecurityFactorError,
CryptoError, CryptoError,
EncryptionError,
DecryptionError, DecryptionError,
EmbeddingError,
EncryptionError,
ExtractionError,
ImageValidationError,
InvalidHeaderError,
KeyDerivationError, KeyDerivationError,
KeyGenerationError, KeyGenerationError,
KeyPasswordError, KeyPasswordError,
KeyValidationError,
MessageValidationError,
PinValidationError,
SecurityFactorError,
SteganographyError, SteganographyError,
CapacityError, StegasooError,
ExtractionError, ValidationError,
EmbeddingError,
InvalidHeaderError,
) )
# Constants # Models
from .constants import ( from .models import (
FORMAT_VERSION, CapacityComparison,
MIN_PASSPHRASE_WORDS, Credentials,
RECOMMENDED_PASSPHRASE_WORDS, DecodeResult,
DEFAULT_PASSPHRASE_WORDS, EncodeResult,
MAX_PASSPHRASE_WORDS, FilePayload,
MIN_PIN_LENGTH, GenerateResult,
MAX_PIN_LENGTH, ImageInfo,
MAX_MESSAGE_SIZE, ValidationResult,
MIN_IMAGE_PIXELS, )
MAX_IMAGE_PIXELS, from .validation import (
LOSSLESS_FORMATS, validate_dct_color_mode,
EMBED_MODE_LSB, validate_dct_output_format,
EMBED_MODE_DCT, validate_embed_mode,
EMBED_MODE_AUTO,
) )
# Aliases for backward compatibility # Aliases for backward compatibility
@@ -159,7 +157,7 @@ __all__ = [
"decode", "decode",
"decode_file", "decode_file",
"decode_text", "decode_text",
# Generation # Generation
"generate_pin", "generate_pin",
"generate_passphrase", "generate_passphrase",
@@ -167,7 +165,7 @@ __all__ = [
"generate_credentials", "generate_credentials",
"export_rsa_key_pem", "export_rsa_key_pem",
"load_rsa_key", "load_rsa_key",
# Channel key management (v4.0.0) # Channel key management (v4.0.0)
"generate_channel_key", "generate_channel_key",
"get_channel_key", "get_channel_key",
@@ -179,28 +177,28 @@ __all__ = [
"format_channel_key", "format_channel_key",
"get_active_channel_key", "get_active_channel_key",
"get_channel_fingerprint", "get_channel_fingerprint",
# Image utilities # Image utilities
"get_image_info", "get_image_info",
"compare_capacity", "compare_capacity",
# Utilities # Utilities
"generate_filename", "generate_filename",
# Crypto # Crypto
"has_argon2", "has_argon2",
# Steganography # Steganography
"has_dct_support", "has_dct_support",
"compare_modes", "compare_modes",
"will_fit_by_mode", "will_fit_by_mode",
# QR utilities # QR utilities
"generate_qr_code", "generate_qr_code",
"extract_key_from_qr", "extract_key_from_qr",
"detect_and_crop_qr", "detect_and_crop_qr",
"HAS_QR_UTILS", "HAS_QR_UTILS",
# Validation # Validation
"validate_reference_photo", "validate_reference_photo",
"validate_carrier", "validate_carrier",
@@ -214,7 +212,7 @@ __all__ = [
"validate_dct_output_format", "validate_dct_output_format",
"validate_dct_color_mode", "validate_dct_color_mode",
"validate_channel_key", "validate_channel_key",
# Models # Models
"ImageInfo", "ImageInfo",
"CapacityComparison", "CapacityComparison",
@@ -224,7 +222,7 @@ __all__ = [
"FilePayload", "FilePayload",
"Credentials", "Credentials",
"ValidationResult", "ValidationResult",
# Exceptions # Exceptions
"StegasooError", "StegasooError",
"ValidationError", "ValidationError",
@@ -244,7 +242,7 @@ __all__ = [
"ExtractionError", "ExtractionError",
"EmbeddingError", "EmbeddingError",
"InvalidHeaderError", "InvalidHeaderError",
# Constants # Constants
"FORMAT_VERSION", "FORMAT_VERSION",
"MIN_PASSPHRASE_WORDS", "MIN_PASSPHRASE_WORDS",

View File

@@ -9,15 +9,14 @@ Changes in v3.2.0:
- Updated all credential handling to use v3.2.0 API - Updated all credential handling to use v3.2.0 API
""" """
import os
import json import json
import time
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional, Callable, Iterator
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading import threading
import time
from collections.abc import Callable, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from .constants import ALLOWED_IMAGE_EXTENSIONS, LOSSLESS_FORMATS from .constants import ALLOWED_IMAGE_EXTENSIONS, LOSSLESS_FORMATS
@@ -35,22 +34,22 @@ class BatchStatus(Enum):
class BatchItem: class BatchItem:
"""Represents a single item in a batch operation.""" """Represents a single item in a batch operation."""
input_path: Path input_path: Path
output_path: Optional[Path] = None output_path: Path | None = None
status: BatchStatus = BatchStatus.PENDING status: BatchStatus = BatchStatus.PENDING
error: Optional[str] = None error: str | None = None
start_time: Optional[float] = None start_time: float | None = None
end_time: Optional[float] = None end_time: float | None = None
input_size: int = 0 input_size: int = 0
output_size: int = 0 output_size: int = 0
message: str = "" message: str = ""
@property @property
def duration(self) -> Optional[float]: def duration(self) -> float | None:
"""Processing duration in seconds.""" """Processing duration in seconds."""
if self.start_time and self.end_time: if self.start_time and self.end_time:
return self.end_time - self.start_time return self.end_time - self.start_time
return None return None
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization.""" """Convert to dictionary for JSON serialization."""
return { return {
@@ -69,14 +68,14 @@ class BatchItem:
class BatchCredentials: class BatchCredentials:
""" """
Credentials for batch encode/decode operations (v3.2.0). Credentials for batch encode/decode operations (v3.2.0).
Provides a structured way to pass authentication factors Provides a structured way to pass authentication factors
for batch processing instead of using plain dicts. for batch processing instead of using plain dicts.
Changes in v3.2.0: Changes in v3.2.0:
- Renamed day_phrase → passphrase - Renamed day_phrase → passphrase
- Removed date_str (no longer used in cryptographic operations) - Removed date_str (no longer used in cryptographic operations)
Example: Example:
creds = BatchCredentials( creds = BatchCredentials(
reference_photo=ref_bytes, reference_photo=ref_bytes,
@@ -88,9 +87,9 @@ class BatchCredentials:
reference_photo: bytes reference_photo: bytes
passphrase: str # v3.2.0: renamed from day_phrase passphrase: str # v3.2.0: renamed from day_phrase
pin: str = "" pin: str = ""
rsa_key_data: Optional[bytes] = None rsa_key_data: bytes | None = None
rsa_password: Optional[str] = None rsa_password: str | None = None
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""Convert to dictionary for API compatibility.""" """Convert to dictionary for API compatibility."""
return { return {
@@ -100,17 +99,17 @@ class BatchCredentials:
"rsa_key_data": self.rsa_key_data, "rsa_key_data": self.rsa_key_data,
"rsa_password": self.rsa_password, "rsa_password": self.rsa_password,
} }
@classmethod @classmethod
def from_dict(cls, data: dict) -> 'BatchCredentials': def from_dict(cls, data: dict) -> 'BatchCredentials':
""" """
Create BatchCredentials from a dictionary. Create BatchCredentials from a dictionary.
Handles both v3.2.0 format (passphrase) and legacy format (day_phrase). Handles both v3.2.0 format (passphrase) and legacy format (day_phrase).
""" """
# Handle legacy 'day_phrase' key # Handle legacy 'day_phrase' key
passphrase = data.get('passphrase') or data.get('day_phrase', '') passphrase = data.get('passphrase') or data.get('day_phrase', '')
return cls( return cls(
reference_photo=data['reference_photo'], reference_photo=data['reference_photo'],
passphrase=passphrase, passphrase=passphrase,
@@ -129,16 +128,16 @@ class BatchResult:
failed: int = 0 failed: int = 0
skipped: int = 0 skipped: int = 0
start_time: float = field(default_factory=time.time) start_time: float = field(default_factory=time.time)
end_time: Optional[float] = None end_time: float | None = None
items: list[BatchItem] = field(default_factory=list) items: list[BatchItem] = field(default_factory=list)
@property @property
def duration(self) -> Optional[float]: def duration(self) -> float | None:
"""Total batch duration in seconds.""" """Total batch duration in seconds."""
if self.end_time: if self.end_time:
return self.end_time - self.start_time return self.end_time - self.start_time
return None return None
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization.""" """Convert to dictionary for JSON serialization."""
return { return {
@@ -152,7 +151,7 @@ class BatchResult:
}, },
"items": [item.to_dict() for item in self.items], "items": [item.to_dict() for item in self.items],
} }
def to_json(self, indent: int = 2) -> str: def to_json(self, indent: int = 2) -> str:
"""Serialize to JSON string.""" """Serialize to JSON string."""
return json.dumps(self.to_dict(), indent=indent) return json.dumps(self.to_dict(), indent=indent)
@@ -165,10 +164,10 @@ ProgressCallback = Callable[[int, int, BatchItem], None]
class BatchProcessor: class BatchProcessor:
""" """
Handles batch encoding/decoding operations (v3.2.0). Handles batch encoding/decoding operations (v3.2.0).
Usage: Usage:
processor = BatchProcessor(max_workers=4) processor = BatchProcessor(max_workers=4)
# Batch encode with BatchCredentials # Batch encode with BatchCredentials
creds = BatchCredentials( creds = BatchCredentials(
reference_photo=ref_bytes, reference_photo=ref_bytes,
@@ -181,7 +180,7 @@ class BatchProcessor:
output_dir="./encoded/", output_dir="./encoded/",
credentials=creds, credentials=creds,
) )
# Batch encode with dict credentials # Batch encode with dict credentials
result = processor.batch_encode( result = processor.batch_encode(
images=['img1.png', 'img2.png'], images=['img1.png', 'img2.png'],
@@ -192,24 +191,24 @@ class BatchProcessor:
"pin": "123456" "pin": "123456"
}, },
) )
# Batch decode # Batch decode
result = processor.batch_decode( result = processor.batch_decode(
images=['encoded1.png', 'encoded2.png'], images=['encoded1.png', 'encoded2.png'],
credentials=creds, credentials=creds,
) )
""" """
def __init__(self, max_workers: int = 4): def __init__(self, max_workers: int = 4):
""" """
Initialize batch processor. Initialize batch processor.
Args: Args:
max_workers: Maximum parallel workers (default 4) max_workers: Maximum parallel workers (default 4)
""" """
self.max_workers = max_workers self.max_workers = max_workers
self._lock = threading.Lock() self._lock = threading.Lock()
def find_images( def find_images(
self, self,
paths: list[str | Path], paths: list[str | Path],
@@ -217,67 +216,67 @@ class BatchProcessor:
) -> Iterator[Path]: ) -> Iterator[Path]:
""" """
Find all valid image files from paths. Find all valid image files from paths.
Args: Args:
paths: List of files or directories paths: List of files or directories
recursive: Search directories recursively recursive: Search directories recursively
Yields: Yields:
Path objects for each valid image Path objects for each valid image
""" """
for path in paths: for path in paths:
path = Path(path) path = Path(path)
if path.is_file(): if path.is_file():
if self._is_valid_image(path): if self._is_valid_image(path):
yield path yield path
elif path.is_dir(): elif path.is_dir():
pattern = '**/*' if recursive else '*' pattern = '**/*' if recursive else '*'
for file_path in path.glob(pattern): for file_path in path.glob(pattern):
if file_path.is_file() and self._is_valid_image(file_path): if file_path.is_file() and self._is_valid_image(file_path):
yield file_path yield file_path
def _is_valid_image(self, path: Path) -> bool: def _is_valid_image(self, path: Path) -> bool:
"""Check if path is a valid image file.""" """Check if path is a valid image file."""
return path.suffix.lower().lstrip('.') in ALLOWED_IMAGE_EXTENSIONS return path.suffix.lower().lstrip('.') in ALLOWED_IMAGE_EXTENSIONS
def _normalize_credentials( def _normalize_credentials(
self, self,
credentials: dict | BatchCredentials | None credentials: dict | BatchCredentials | None
) -> BatchCredentials: ) -> BatchCredentials:
""" """
Normalize credentials to BatchCredentials object. Normalize credentials to BatchCredentials object.
Handles both dict and BatchCredentials input, and legacy 'day_phrase' key. Handles both dict and BatchCredentials input, and legacy 'day_phrase' key.
""" """
if credentials is None: if credentials is None:
raise ValueError("Credentials are required") raise ValueError("Credentials are required")
if isinstance(credentials, BatchCredentials): if isinstance(credentials, BatchCredentials):
return credentials return credentials
if isinstance(credentials, dict): if isinstance(credentials, dict):
return BatchCredentials.from_dict(credentials) return BatchCredentials.from_dict(credentials)
raise ValueError(f"Invalid credentials type: {type(credentials)}") raise ValueError(f"Invalid credentials type: {type(credentials)}")
def batch_encode( def batch_encode(
self, self,
images: list[str | Path], images: list[str | Path],
message: Optional[str] = None, message: str | None = None,
file_payload: Optional[Path] = None, file_payload: Path | None = None,
output_dir: Optional[Path] = None, output_dir: Path | None = None,
output_suffix: str = "_encoded", output_suffix: str = "_encoded",
credentials: dict | BatchCredentials | None = None, credentials: dict | BatchCredentials | None = None,
compress: bool = True, compress: bool = True,
recursive: bool = False, recursive: bool = False,
progress_callback: Optional[ProgressCallback] = None, progress_callback: ProgressCallback | None = None,
encode_func: Callable = None, encode_func: Callable = None,
) -> BatchResult: ) -> BatchResult:
""" """
Encode message into multiple images. Encode message into multiple images.
Args: Args:
images: List of image paths or directories images: List of image paths or directories
message: Text message to encode (mutually exclusive with file_payload) message: Text message to encode (mutually exclusive with file_payload)
@@ -289,43 +288,43 @@ class BatchProcessor:
recursive: Search directories recursively recursive: Search directories recursively
progress_callback: Called for each item: callback(current, total, item) progress_callback: Called for each item: callback(current, total, item)
encode_func: Custom encode function (for integration) encode_func: Custom encode function (for integration)
Returns: Returns:
BatchResult with operation summary BatchResult with operation summary
""" """
if message is None and file_payload is None: if message is None and file_payload is None:
raise ValueError("Either message or file_payload must be provided") raise ValueError("Either message or file_payload must be provided")
# Normalize credentials to BatchCredentials # Normalize credentials to BatchCredentials
creds = self._normalize_credentials(credentials) creds = self._normalize_credentials(credentials)
result = BatchResult(operation="encode") result = BatchResult(operation="encode")
image_paths = list(self.find_images(images, recursive)) image_paths = list(self.find_images(images, recursive))
result.total = len(image_paths) result.total = len(image_paths)
if output_dir: if output_dir:
output_dir = Path(output_dir) output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Prepare batch items # Prepare batch items
for img_path in image_paths: for img_path in image_paths:
if output_dir: if output_dir:
out_path = output_dir / f"{img_path.stem}{output_suffix}.png" out_path = output_dir / f"{img_path.stem}{output_suffix}.png"
else: else:
out_path = img_path.parent / f"{img_path.stem}{output_suffix}.png" out_path = img_path.parent / f"{img_path.stem}{output_suffix}.png"
item = BatchItem( item = BatchItem(
input_path=img_path, input_path=img_path,
output_path=out_path, output_path=out_path,
input_size=img_path.stat().st_size if img_path.exists() else 0, input_size=img_path.stat().st_size if img_path.exists() else 0,
) )
result.items.append(item) result.items.append(item)
# Process items # Process items
def process_encode(item: BatchItem) -> BatchItem: def process_encode(item: BatchItem) -> BatchItem:
item.status = BatchStatus.PROCESSING item.status = BatchStatus.PROCESSING
item.start_time = time.time() item.start_time = time.time()
try: try:
if encode_func: if encode_func:
# Use provided encode function # Use provided encode function
@@ -340,35 +339,35 @@ class BatchProcessor:
else: else:
# Use stegasoo encode # Use stegasoo encode
self._do_encode(item, message, file_payload, creds, compress) self._do_encode(item, message, file_payload, creds, compress)
item.status = BatchStatus.SUCCESS item.status = BatchStatus.SUCCESS
item.output_size = item.output_path.stat().st_size if item.output_path and item.output_path.exists() else 0 item.output_size = item.output_path.stat().st_size if item.output_path and item.output_path.exists() else 0
item.message = f"Encoded to {item.output_path.name}" item.message = f"Encoded to {item.output_path.name}"
except Exception as e: except Exception as e:
item.status = BatchStatus.FAILED item.status = BatchStatus.FAILED
item.error = str(e) item.error = str(e)
item.end_time = time.time() item.end_time = time.time()
return item return item
# Execute with thread pool # Execute with thread pool
self._execute_batch(result, process_encode, progress_callback) self._execute_batch(result, process_encode, progress_callback)
return result return result
def batch_decode( def batch_decode(
self, self,
images: list[str | Path], images: list[str | Path],
output_dir: Optional[Path] = None, output_dir: Path | None = None,
credentials: dict | BatchCredentials | None = None, credentials: dict | BatchCredentials | None = None,
recursive: bool = False, recursive: bool = False,
progress_callback: Optional[ProgressCallback] = None, progress_callback: ProgressCallback | None = None,
decode_func: Callable = None, decode_func: Callable = None,
) -> BatchResult: ) -> BatchResult:
""" """
Decode messages from multiple images. Decode messages from multiple images.
Args: Args:
images: List of image paths or directories images: List of image paths or directories
output_dir: Output directory for file payloads (default: same as input) output_dir: Output directory for file payloads (default: same as input)
@@ -376,21 +375,21 @@ class BatchProcessor:
recursive: Search directories recursively recursive: Search directories recursively
progress_callback: Called for each item: callback(current, total, item) progress_callback: Called for each item: callback(current, total, item)
decode_func: Custom decode function (for integration) decode_func: Custom decode function (for integration)
Returns: Returns:
BatchResult with decoded messages in item.message fields BatchResult with decoded messages in item.message fields
""" """
# Normalize credentials to BatchCredentials # Normalize credentials to BatchCredentials
creds = self._normalize_credentials(credentials) creds = self._normalize_credentials(credentials)
result = BatchResult(operation="decode") result = BatchResult(operation="decode")
image_paths = list(self.find_images(images, recursive)) image_paths = list(self.find_images(images, recursive))
result.total = len(image_paths) result.total = len(image_paths)
if output_dir: if output_dir:
output_dir = Path(output_dir) output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Prepare batch items # Prepare batch items
for img_path in image_paths: for img_path in image_paths:
item = BatchItem( item = BatchItem(
@@ -399,12 +398,12 @@ class BatchProcessor:
input_size=img_path.stat().st_size if img_path.exists() else 0, input_size=img_path.stat().st_size if img_path.exists() else 0,
) )
result.items.append(item) result.items.append(item)
# Process items # Process items
def process_decode(item: BatchItem) -> BatchItem: def process_decode(item: BatchItem) -> BatchItem:
item.status = BatchStatus.PROCESSING item.status = BatchStatus.PROCESSING
item.start_time = time.time() item.start_time = time.time()
try: try:
if decode_func: if decode_func:
# Use provided decode function # Use provided decode function
@@ -417,40 +416,40 @@ class BatchProcessor:
else: else:
# Use stegasoo decode # Use stegasoo decode
item.message = self._do_decode(item, creds) item.message = self._do_decode(item, creds)
item.status = BatchStatus.SUCCESS item.status = BatchStatus.SUCCESS
except Exception as e: except Exception as e:
item.status = BatchStatus.FAILED item.status = BatchStatus.FAILED
item.error = str(e) item.error = str(e)
item.end_time = time.time() item.end_time = time.time()
return item return item
# Execute with thread pool # Execute with thread pool
self._execute_batch(result, process_decode, progress_callback) self._execute_batch(result, process_decode, progress_callback)
return result return result
def _execute_batch( def _execute_batch(
self, self,
result: BatchResult, result: BatchResult,
process_func: Callable[[BatchItem], BatchItem], process_func: Callable[[BatchItem], BatchItem],
progress_callback: Optional[ProgressCallback] = None, progress_callback: ProgressCallback | None = None,
) -> None: ) -> None:
"""Execute batch processing with thread pool.""" """Execute batch processing with thread pool."""
completed = 0 completed = 0
with ThreadPoolExecutor(max_workers=self.max_workers) as executor: with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = { futures = {
executor.submit(process_func, item): item executor.submit(process_func, item): item
for item in result.items for item in result.items
} }
for future in as_completed(futures): for future in as_completed(futures):
item = future.result() item = future.result()
completed += 1 completed += 1
with self._lock: with self._lock:
if item.status == BatchStatus.SUCCESS: if item.status == BatchStatus.SUCCESS:
result.succeeded += 1 result.succeeded += 1
@@ -458,32 +457,32 @@ class BatchProcessor:
result.failed += 1 result.failed += 1
elif item.status == BatchStatus.SKIPPED: elif item.status == BatchStatus.SKIPPED:
result.skipped += 1 result.skipped += 1
if progress_callback: if progress_callback:
progress_callback(completed, result.total, item) progress_callback(completed, result.total, item)
result.end_time = time.time() result.end_time = time.time()
def _do_encode( def _do_encode(
self, self,
item: BatchItem, item: BatchItem,
message: Optional[str], message: str | None,
file_payload: Optional[Path], file_payload: Path | None,
creds: BatchCredentials, creds: BatchCredentials,
compress: bool compress: bool
) -> None: ) -> None:
""" """
Perform actual encoding using stegasoo.encode. Perform actual encoding using stegasoo.encode.
Override this method to customize encoding behavior. Override this method to customize encoding behavior.
""" """
try: try:
from .encode import encode, encode_file from .encode import encode
from .models import FilePayload from .models import FilePayload
# Read carrier image # Read carrier image
carrier_image = item.input_path.read_bytes() carrier_image = item.input_path.read_bytes()
if file_payload: if file_payload:
# Encode file # Encode file
payload = FilePayload.from_file(str(file_payload)) payload = FilePayload.from_file(str(file_payload))
@@ -507,15 +506,15 @@ class BatchProcessor:
rsa_key_data=creds.rsa_key_data, rsa_key_data=creds.rsa_key_data,
rsa_password=creds.rsa_password, rsa_password=creds.rsa_password,
) )
# Write output # Write output
if item.output_path: if item.output_path:
item.output_path.write_bytes(result.stego_image) item.output_path.write_bytes(result.stego_image)
except ImportError: except ImportError:
# Fallback to mock if stegasoo.encode not available # Fallback to mock if stegasoo.encode not available
self._mock_encode(item, message, creds, compress) self._mock_encode(item, message, creds, compress)
def _do_decode( def _do_decode(
self, self,
item: BatchItem, item: BatchItem,
@@ -523,15 +522,15 @@ class BatchProcessor:
) -> str: ) -> str:
""" """
Perform actual decoding using stegasoo.decode. Perform actual decoding using stegasoo.decode.
Override this method to customize decoding behavior. Override this method to customize decoding behavior.
""" """
try: try:
from .decode import decode from .decode import decode
# Read stego image # Read stego image
stego_image = item.input_path.read_bytes() stego_image = item.input_path.read_bytes()
result = decode( result = decode(
stego_image=stego_image, stego_image=stego_image,
reference_photo=creds.reference_photo, reference_photo=creds.reference_photo,
@@ -540,7 +539,7 @@ class BatchProcessor:
rsa_key_data=creds.rsa_key_data, rsa_key_data=creds.rsa_key_data,
rsa_password=creds.rsa_password, rsa_password=creds.rsa_password,
) )
if result.is_text: if result.is_text:
return result.message or "" return result.message or ""
else: else:
@@ -550,11 +549,11 @@ class BatchProcessor:
output_file.write_bytes(result.file_data) output_file.write_bytes(result.file_data)
return f"File extracted: {result.filename or 'extracted_file'}" return f"File extracted: {result.filename or 'extracted_file'}"
return f"[File: {result.filename or 'binary data'}]" return f"[File: {result.filename or 'binary data'}]"
except ImportError: except ImportError:
# Fallback to mock if stegasoo.decode not available # Fallback to mock if stegasoo.decode not available
return self._mock_decode(item, creds) return self._mock_decode(item, creds)
def _mock_encode( def _mock_encode(
self, self,
item: BatchItem, item: BatchItem,
@@ -568,7 +567,7 @@ class BatchProcessor:
import shutil import shutil
if item.output_path: if item.output_path:
shutil.copy(item.input_path, item.output_path) shutil.copy(item.input_path, item.output_path)
def _mock_decode(self, item: BatchItem, creds: BatchCredentials) -> str: def _mock_decode(self, item: BatchItem, creds: BatchCredentials) -> str:
"""Mock decode for testing - replace with actual stego.decode()""" """Mock decode for testing - replace with actual stego.decode()"""
# This is a placeholder - in real usage, you'd call your actual decode function # This is a placeholder - in real usage, you'd call your actual decode function
@@ -581,30 +580,31 @@ def batch_capacity_check(
) -> list[dict]: ) -> list[dict]:
""" """
Check capacity of multiple images without encoding. Check capacity of multiple images without encoding.
Args: Args:
images: List of image paths or directories images: List of image paths or directories
recursive: Search directories recursively recursive: Search directories recursively
Returns: Returns:
List of dicts with path, dimensions, and estimated capacity List of dicts with path, dimensions, and estimated capacity
""" """
from PIL import Image from PIL import Image
from .constants import MAX_IMAGE_PIXELS from .constants import MAX_IMAGE_PIXELS
processor = BatchProcessor() processor = BatchProcessor()
results = [] results = []
for img_path in processor.find_images(images, recursive): for img_path in processor.find_images(images, recursive):
try: try:
with Image.open(img_path) as img: with Image.open(img_path) as img:
width, height = img.size width, height = img.size
pixels = width * height pixels = width * height
# Estimate: 3 bits per pixel (RGB LSB), minus header overhead # Estimate: 3 bits per pixel (RGB LSB), minus header overhead
capacity_bits = pixels * 3 capacity_bits = pixels * 3
capacity_bytes = (capacity_bits // 8) - 100 # Header overhead capacity_bytes = (capacity_bits // 8) - 100 # Header overhead
results.append({ results.append({
"path": str(img_path), "path": str(img_path),
"dimensions": f"{width}x{height}", "dimensions": f"{width}x{height}",
@@ -622,25 +622,25 @@ def batch_capacity_check(
"error": str(e), "error": str(e),
"valid": False, "valid": False,
}) })
return results return results
def _get_image_warnings(img, path: Path) -> list[str]: def _get_image_warnings(img, path: Path) -> list[str]:
"""Generate warnings for an image.""" """Generate warnings for an image."""
from .constants import MAX_IMAGE_PIXELS, LOSSLESS_FORMATS from .constants import LOSSLESS_FORMATS, MAX_IMAGE_PIXELS
warnings = [] warnings = []
if img.format not in LOSSLESS_FORMATS: if img.format not in LOSSLESS_FORMATS:
warnings.append(f"Lossy format ({img.format}) - quality will degrade on re-save") warnings.append(f"Lossy format ({img.format}) - quality will degrade on re-save")
if img.size[0] * img.size[1] > MAX_IMAGE_PIXELS: if img.size[0] * img.size[1] > MAX_IMAGE_PIXELS:
warnings.append(f"Image exceeds {MAX_IMAGE_PIXELS:,} pixel limit") warnings.append(f"Image exceeds {MAX_IMAGE_PIXELS:,} pixel limit")
if img.mode not in ('RGB', 'RGBA'): if img.mode not in ('RGB', 'RGBA'):
warnings.append(f"Non-RGB mode ({img.mode}) - will be converted") warnings.append(f"Non-RGB mode ({img.mode}) - will be converted")
return warnings return warnings
@@ -657,7 +657,7 @@ def print_batch_result(result: BatchResult, verbose: bool = False) -> None:
print(f"Skipped: {result.skipped}") print(f"Skipped: {result.skipped}")
if result.duration: if result.duration:
print(f"Duration: {result.duration:.2f}s") print(f"Duration: {result.duration:.2f}s")
if verbose or result.failed > 0: if verbose or result.failed > 0:
print(f"\n{''*60}") print(f"\n{''*60}")
for item in result.items: for item in result.items:
@@ -668,7 +668,7 @@ def print_batch_result(result: BatchResult, verbose: bool = False) -> None:
BatchStatus.PENDING: "", BatchStatus.PENDING: "",
BatchStatus.PROCESSING: "", BatchStatus.PROCESSING: "",
}.get(item.status, "?") }.get(item.status, "?")
print(f"{status_icon} {item.input_path.name}") print(f"{status_icon} {item.input_path.name}")
if item.error: if item.error:
print(f" Error: {item.error}") print(f" Error: {item.error}")

View File

@@ -24,12 +24,11 @@ INTEGRATION STATUS (v4.0.0):
- ✅ Helpful error messages for channel key mismatches - ✅ Helpful error messages for channel key mismatches
""" """
import os
import secrets
import hashlib import hashlib
import os
import re import re
import secrets
from pathlib import Path from pathlib import Path
from typing import Optional, List
from .debug import debug from .debug import debug
@@ -52,10 +51,10 @@ CONFIG_LOCATIONS = [
def generate_channel_key() -> str: def generate_channel_key() -> str:
""" """
Generate a new random channel key. Generate a new random channel key.
Returns: Returns:
Formatted channel key (e.g., "ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456") Formatted channel key (e.g., "ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456")
Example: Example:
>>> key = generate_channel_key() >>> key = generate_channel_key()
>>> len(key) >>> len(key)
@@ -64,7 +63,7 @@ def generate_channel_key() -> str:
# Generate 32 random alphanumeric characters # Generate 32 random alphanumeric characters
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
raw_key = ''.join(secrets.choice(alphabet) for _ in range(CHANNEL_KEY_LENGTH)) raw_key = ''.join(secrets.choice(alphabet) for _ in range(CHANNEL_KEY_LENGTH))
formatted = format_channel_key(raw_key) formatted = format_channel_key(raw_key)
debug.print(f"Generated channel key: {get_channel_fingerprint(formatted)}") debug.print(f"Generated channel key: {get_channel_fingerprint(formatted)}")
return formatted return formatted
@@ -73,32 +72,32 @@ def generate_channel_key() -> str:
def format_channel_key(raw_key: str) -> str: def format_channel_key(raw_key: str) -> str:
""" """
Format a raw key string into the standard format. Format a raw key string into the standard format.
Args: Args:
raw_key: Raw key string (with or without dashes) raw_key: Raw key string (with or without dashes)
Returns: Returns:
Formatted key with dashes (XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX) Formatted key with dashes (XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX)
Raises: Raises:
ValueError: If key is invalid length or contains invalid characters ValueError: If key is invalid length or contains invalid characters
Example: Example:
>>> format_channel_key("ABCD1234EFGH5678IJKL9012MNOP3456") >>> format_channel_key("ABCD1234EFGH5678IJKL9012MNOP3456")
"ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456" "ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456"
""" """
# Remove any existing dashes, spaces, and convert to uppercase # Remove any existing dashes, spaces, and convert to uppercase
clean = raw_key.replace('-', '').replace(' ', '').upper() clean = raw_key.replace('-', '').replace(' ', '').upper()
if len(clean) != CHANNEL_KEY_LENGTH: if len(clean) != CHANNEL_KEY_LENGTH:
raise ValueError( raise ValueError(
f"Channel key must be {CHANNEL_KEY_LENGTH} characters (got {len(clean)})" f"Channel key must be {CHANNEL_KEY_LENGTH} characters (got {len(clean)})"
) )
# Validate characters # Validate characters
if not all(c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' for c in clean): if not all(c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' for c in clean):
raise ValueError("Channel key must contain only letters A-Z and digits 0-9") raise ValueError("Channel key must contain only letters A-Z and digits 0-9")
# Format with dashes every 4 characters # Format with dashes every 4 characters
return '-'.join(clean[i:i+4] for i in range(0, CHANNEL_KEY_LENGTH, 4)) return '-'.join(clean[i:i+4] for i in range(0, CHANNEL_KEY_LENGTH, 4))
@@ -106,13 +105,13 @@ def format_channel_key(raw_key: str) -> str:
def validate_channel_key(key: str) -> bool: def validate_channel_key(key: str) -> bool:
""" """
Validate a channel key format. Validate a channel key format.
Args: Args:
key: Channel key to validate key: Channel key to validate
Returns: Returns:
True if valid format, False otherwise True if valid format, False otherwise
Example: Example:
>>> validate_channel_key("ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456") >>> validate_channel_key("ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456")
True True
@@ -121,7 +120,7 @@ def validate_channel_key(key: str) -> bool:
""" """
if not key: if not key:
return False return False
try: try:
formatted = format_channel_key(key) formatted = format_channel_key(key)
return bool(CHANNEL_KEY_PATTERN.match(formatted)) return bool(CHANNEL_KEY_PATTERN.match(formatted))
@@ -129,18 +128,18 @@ def validate_channel_key(key: str) -> bool:
return False return False
def get_channel_key() -> Optional[str]: def get_channel_key() -> str | None:
""" """
Get the current channel key from environment or config. Get the current channel key from environment or config.
Checks in order: Checks in order:
1. STEGASOO_CHANNEL_KEY environment variable 1. STEGASOO_CHANNEL_KEY environment variable
2. ./config/channel.key file 2. ./config/channel.key file
3. ~/.stegasoo/channel.key file 3. ~/.stegasoo/channel.key file
Returns: Returns:
Channel key if configured, None if in public mode Channel key if configured, None if in public mode
Example: Example:
>>> key = get_channel_key() >>> key = get_channel_key()
>>> if key: >>> if key:
@@ -156,7 +155,7 @@ def get_channel_key() -> Optional[str]:
return format_channel_key(env_key) return format_channel_key(env_key)
else: else:
debug.print(f"Warning: Invalid {CHANNEL_KEY_ENV_VAR} format, ignoring") debug.print(f"Warning: Invalid {CHANNEL_KEY_ENV_VAR} format, ignoring")
# 2. Check config files # 2. Check config files
for config_path in CONFIG_LOCATIONS: for config_path in CONFIG_LOCATIONS:
if config_path.exists(): if config_path.exists():
@@ -165,10 +164,10 @@ def get_channel_key() -> Optional[str]:
if key and validate_channel_key(key): if key and validate_channel_key(key):
debug.print(f"Channel key from {config_path}: {get_channel_fingerprint(key)}") debug.print(f"Channel key from {config_path}: {get_channel_fingerprint(key)}")
return format_channel_key(key) return format_channel_key(key)
except (IOError, PermissionError) as e: except (OSError, PermissionError) as e:
debug.print(f"Could not read {config_path}: {e}") debug.print(f"Could not read {config_path}: {e}")
continue continue
# 3. No channel key configured (public mode) # 3. No channel key configured (public mode)
debug.print("No channel key configured (public mode)") debug.print("No channel key configured (public mode)")
return None return None
@@ -177,92 +176,92 @@ def get_channel_key() -> Optional[str]:
def set_channel_key(key: str, location: str = 'project') -> Path: def set_channel_key(key: str, location: str = 'project') -> Path:
""" """
Save a channel key to config file. Save a channel key to config file.
Args: Args:
key: Channel key to save (will be formatted) key: Channel key to save (will be formatted)
location: 'project' for ./config/ or 'user' for ~/.stegasoo/ location: 'project' for ./config/ or 'user' for ~/.stegasoo/
Returns: Returns:
Path where key was saved Path where key was saved
Raises: Raises:
ValueError: If key format is invalid ValueError: If key format is invalid
Example: Example:
>>> path = set_channel_key("ABCD1234EFGH5678IJKL9012MNOP3456") >>> path = set_channel_key("ABCD1234EFGH5678IJKL9012MNOP3456")
>>> print(path) >>> print(path)
./config/channel.key ./config/channel.key
""" """
formatted = format_channel_key(key) formatted = format_channel_key(key)
if location == 'user': if location == 'user':
config_path = Path.home() / '.stegasoo' / 'channel.key' config_path = Path.home() / '.stegasoo' / 'channel.key'
else: else:
config_path = Path('./config/channel.key') config_path = Path('./config/channel.key')
# Create directory if needed # Create directory if needed
config_path.parent.mkdir(parents=True, exist_ok=True) config_path.parent.mkdir(parents=True, exist_ok=True)
# Write key with newline # Write key with newline
config_path.write_text(formatted + '\n') config_path.write_text(formatted + '\n')
# Set restrictive permissions (owner read/write only) # Set restrictive permissions (owner read/write only)
try: try:
config_path.chmod(0o600) config_path.chmod(0o600)
except (OSError, AttributeError): except (OSError, AttributeError):
pass # Windows doesn't support chmod the same way pass # Windows doesn't support chmod the same way
debug.print(f"Channel key saved to {config_path}") debug.print(f"Channel key saved to {config_path}")
return config_path return config_path
def clear_channel_key(location: str = 'all') -> List[Path]: def clear_channel_key(location: str = 'all') -> list[Path]:
""" """
Remove channel key configuration. Remove channel key configuration.
Args: Args:
location: 'project', 'user', or 'all' location: 'project', 'user', or 'all'
Returns: Returns:
List of paths that were deleted List of paths that were deleted
Example: Example:
>>> deleted = clear_channel_key('all') >>> deleted = clear_channel_key('all')
>>> print(f"Removed {len(deleted)} files") >>> print(f"Removed {len(deleted)} files")
""" """
deleted = [] deleted = []
paths_to_check = [] paths_to_check = []
if location in ('project', 'all'): if location in ('project', 'all'):
paths_to_check.append(Path('./config/channel.key')) paths_to_check.append(Path('./config/channel.key'))
if location in ('user', 'all'): if location in ('user', 'all'):
paths_to_check.append(Path.home() / '.stegasoo' / 'channel.key') paths_to_check.append(Path.home() / '.stegasoo' / 'channel.key')
for path in paths_to_check: for path in paths_to_check:
if path.exists(): if path.exists():
try: try:
path.unlink() path.unlink()
deleted.append(path) deleted.append(path)
debug.print(f"Removed channel key: {path}") debug.print(f"Removed channel key: {path}")
except (IOError, PermissionError) as e: except (OSError, PermissionError) as e:
debug.print(f"Could not remove {path}: {e}") debug.print(f"Could not remove {path}: {e}")
return deleted return deleted
def get_channel_key_hash(key: Optional[str] = None) -> Optional[bytes]: def get_channel_key_hash(key: str | None = None) -> bytes | None:
""" """
Get the channel key as a 32-byte hash suitable for key derivation. Get the channel key as a 32-byte hash suitable for key derivation.
This hash is mixed into the Argon2 key derivation to bind This hash is mixed into the Argon2 key derivation to bind
encryption to a specific channel. encryption to a specific channel.
Args: Args:
key: Channel key (if None, reads from config) key: Channel key (if None, reads from config)
Returns: Returns:
32-byte SHA-256 hash of channel key, or None if no channel key 32-byte SHA-256 hash of channel key, or None if no channel key
Example: Example:
>>> hash_bytes = get_channel_key_hash() >>> hash_bytes = get_channel_key_hash()
>>> if hash_bytes: >>> if hash_bytes:
@@ -270,39 +269,39 @@ def get_channel_key_hash(key: Optional[str] = None) -> Optional[bytes]:
""" """
if key is None: if key is None:
key = get_channel_key() key = get_channel_key()
if not key: if not key:
return None return None
# Hash the formatted key to get consistent 32 bytes # Hash the formatted key to get consistent 32 bytes
formatted = format_channel_key(key) formatted = format_channel_key(key)
return hashlib.sha256(formatted.encode('utf-8')).digest() return hashlib.sha256(formatted.encode('utf-8')).digest()
def get_channel_fingerprint(key: Optional[str] = None) -> Optional[str]: def get_channel_fingerprint(key: str | None = None) -> str | None:
""" """
Get a short fingerprint for display purposes. Get a short fingerprint for display purposes.
Shows first and last 4 chars with masked middle. Shows first and last 4 chars with masked middle.
Args: Args:
key: Channel key (if None, reads from config) key: Channel key (if None, reads from config)
Returns: Returns:
Fingerprint like "ABCD-••••-••••-••••-••••-••••-••••-3456" or None Fingerprint like "ABCD-••••-••••-••••-••••-••••-••••-3456" or None
Example: Example:
>>> print(get_channel_fingerprint()) >>> print(get_channel_fingerprint())
ABCD-••••-••••-••••-••••-••••-••••-3456 ABCD-••••-••••-••••-••••-••••-••••-3456
""" """
if key is None: if key is None:
key = get_channel_key() key = get_channel_key()
if not key: if not key:
return None return None
formatted = format_channel_key(key) formatted = format_channel_key(key)
parts = formatted.split('-') parts = formatted.split('-')
# Show first and last group, mask the rest # Show first and last group, mask the rest
masked = [parts[0]] + ['••••'] * 6 + [parts[-1]] masked = [parts[0]] + ['••••'] * 6 + [parts[-1]]
return '-'.join(masked) return '-'.join(masked)
@@ -311,7 +310,7 @@ def get_channel_fingerprint(key: Optional[str] = None) -> Optional[str]:
def get_channel_status() -> dict: def get_channel_status() -> dict:
""" """
Get comprehensive channel key status. Get comprehensive channel key status.
Returns: Returns:
Dictionary with: Dictionary with:
- mode: 'private' or 'public' - mode: 'private' or 'public'
@@ -319,14 +318,14 @@ def get_channel_status() -> dict:
- fingerprint: masked key or None - fingerprint: masked key or None
- source: where key came from or None - source: where key came from or None
- key: full key (for export) or None - key: full key (for export) or None
Example: Example:
>>> status = get_channel_status() >>> status = get_channel_status()
>>> print(f"Mode: {status['mode']}") >>> print(f"Mode: {status['mode']}")
Mode: private Mode: private
""" """
key = get_channel_key() key = get_channel_key()
if key: if key:
# Find which source provided the key # Find which source provided the key
source = 'unknown' source = 'unknown'
@@ -341,9 +340,9 @@ def get_channel_status() -> dict:
if file_key and format_channel_key(file_key) == key: if file_key and format_channel_key(file_key) == key:
source = str(config_path) source = str(config_path)
break break
except (IOError, PermissionError): except (OSError, PermissionError):
continue continue
return { return {
'mode': 'private', 'mode': 'private',
'configured': True, 'configured': True,
@@ -364,10 +363,10 @@ def get_channel_status() -> dict:
def has_channel_key() -> bool: def has_channel_key() -> bool:
""" """
Quick check if a channel key is configured. Quick check if a channel key is configured.
Returns: Returns:
True if channel key is set, False for public mode True if channel key is set, False for public mode
Example: Example:
>>> if has_channel_key(): >>> if has_channel_key():
... print("Private channel active") ... print("Private channel active")
@@ -381,7 +380,7 @@ def has_channel_key() -> bool:
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
def print_status(): def print_status():
"""Print current channel status.""" """Print current channel status."""
status = get_channel_status() status = get_channel_status()
@@ -391,7 +390,7 @@ if __name__ == '__main__':
print(f"Source: {status['source']}") print(f"Source: {status['source']}")
else: else:
print("No channel key configured (public mode)") print("No channel key configured (public mode)")
if len(sys.argv) < 2: if len(sys.argv) < 2:
print("Channel Key Manager") print("Channel Key Manager")
print("=" * 40) print("=" * 40)
@@ -404,24 +403,24 @@ if __name__ == '__main__':
print(" python -m stegasoo.channel clear - Remove channel key") print(" python -m stegasoo.channel clear - Remove channel key")
print(" python -m stegasoo.channel status - Show status") print(" python -m stegasoo.channel status - Show status")
sys.exit(0) sys.exit(0)
cmd = sys.argv[1].lower() cmd = sys.argv[1].lower()
if cmd == 'generate': if cmd == 'generate':
key = generate_channel_key() key = generate_channel_key()
print(f"Generated channel key:") print("Generated channel key:")
print(f" {key}") print(f" {key}")
print() print()
save = input("Save to config? [y/N]: ").strip().lower() save = input("Save to config? [y/N]: ").strip().lower()
if save == 'y': if save == 'y':
path = set_channel_key(key) path = set_channel_key(key)
print(f"Saved to: {path}") print(f"Saved to: {path}")
elif cmd == 'set': elif cmd == 'set':
if len(sys.argv) < 3: if len(sys.argv) < 3:
print("Usage: python -m stegasoo.channel set <KEY>") print("Usage: python -m stegasoo.channel set <KEY>")
sys.exit(1) sys.exit(1)
try: try:
key = sys.argv[2] key = sys.argv[2]
formatted = format_channel_key(key) formatted = format_channel_key(key)
@@ -431,7 +430,7 @@ if __name__ == '__main__':
except ValueError as e: except ValueError as e:
print(f"Error: {e}") print(f"Error: {e}")
sys.exit(1) sys.exit(1)
elif cmd == 'show': elif cmd == 'show':
status = get_channel_status() status = get_channel_status()
if status['configured']: if status['configured']:
@@ -439,17 +438,17 @@ if __name__ == '__main__':
print(f"Source: {status['source']}") print(f"Source: {status['source']}")
else: else:
print("No channel key configured") print("No channel key configured")
elif cmd == 'clear': elif cmd == 'clear':
deleted = clear_channel_key('all') deleted = clear_channel_key('all')
if deleted: if deleted:
print(f"Removed channel key from: {', '.join(str(p) for p in deleted)}") print(f"Removed channel key from: {', '.join(str(p) for p in deleted)}")
else: else:
print("No channel key files found") print("No channel key files found")
elif cmd == 'status': elif cmd == 'status':
print_status() print_status()
else: else:
print(f"Unknown command: {cmd}") print(f"Unknown command: {cmd}")
sys.exit(1) sys.exit(1)

View File

@@ -8,33 +8,29 @@ Changes in v3.2.0:
- Updated help text to use 'passphrase' terminology - Updated help text to use 'passphrase' terminology
""" """
import sys
import json import json
from pathlib import Path from pathlib import Path
from typing import Optional
import click import click
from .constants import (
__version__,
MAX_MESSAGE_SIZE,
MAX_FILE_PAYLOAD_SIZE,
DEFAULT_PIN_LENGTH,
DEFAULT_PASSPHRASE_WORDS, # v3.2.0: renamed from DEFAULT_PHRASE_WORDS
)
from .compression import (
CompressionAlgorithm,
get_available_algorithms,
algorithm_name,
HAS_LZ4,
)
from .batch import ( from .batch import (
BatchProcessor, BatchProcessor,
BatchResult,
batch_capacity_check, batch_capacity_check,
print_batch_result, print_batch_result,
) )
from .compression import (
HAS_LZ4,
CompressionAlgorithm,
algorithm_name,
get_available_algorithms,
)
from .constants import (
DEFAULT_PASSPHRASE_WORDS, # v3.2.0: renamed from DEFAULT_PHRASE_WORDS
DEFAULT_PIN_LENGTH,
MAX_FILE_PAYLOAD_SIZE,
MAX_MESSAGE_SIZE,
__version__,
)
# Click context settings # Click context settings
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@@ -47,7 +43,7 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
def cli(ctx, json_output): def cli(ctx, json_output):
""" """
Stegasoo - Steganography with hybrid authentication. Stegasoo - Steganography with hybrid authentication.
Hide messages in images using PIN + passphrase security. Hide messages in images using PIN + passphrase security.
""" """
ctx.ensure_object(dict) ctx.ensure_object(dict)
@@ -61,35 +57,35 @@ def cli(ctx, json_output):
@cli.command() @cli.command()
@click.argument('image', type=click.Path(exists=True)) @click.argument('image', type=click.Path(exists=True))
@click.option('-m', '--message', help='Message to encode') @click.option('-m', '--message', help='Message to encode')
@click.option('-f', '--file', 'file_payload', type=click.Path(exists=True), @click.option('-f', '--file', 'file_payload', type=click.Path(exists=True),
help='File to embed instead of message') help='File to embed instead of message')
@click.option('-o', '--output', type=click.Path(), help='Output image path') @click.option('-o', '--output', type=click.Path(), help='Output image path')
@click.option('--passphrase', prompt=True, hide_input=True, @click.option('--passphrase', prompt=True, hide_input=True,
confirmation_prompt=True, help='Passphrase (recommend 4+ words)') confirmation_prompt=True, help='Passphrase (recommend 4+ words)')
@click.option('--pin', prompt=True, hide_input=True, @click.option('--pin', prompt=True, hide_input=True,
confirmation_prompt=True, help='PIN code') confirmation_prompt=True, help='PIN code')
@click.option('--compress/--no-compress', default=True, @click.option('--compress/--no-compress', default=True,
help='Enable/disable compression (default: enabled)') help='Enable/disable compression (default: enabled)')
@click.option('--algorithm', type=click.Choice(['zlib', 'lz4', 'none']), @click.option('--algorithm', type=click.Choice(['zlib', 'lz4', 'none']),
default='zlib', help='Compression algorithm') default='zlib', help='Compression algorithm')
@click.option('--dry-run', is_flag=True, help='Show capacity usage without encoding') @click.option('--dry-run', is_flag=True, help='Show capacity usage without encoding')
@click.pass_context @click.pass_context
def encode(ctx, image, message, file_payload, output, passphrase, pin, def encode(ctx, image, message, file_payload, output, passphrase, pin,
compress, algorithm, dry_run): compress, algorithm, dry_run):
""" """
Encode a message or file into an image. Encode a message or file into an image.
Examples: Examples:
stegasoo encode photo.png -m "Secret message" --passphrase --pin stegasoo encode photo.png -m "Secret message" --passphrase --pin
stegasoo encode photo.png -f secret.pdf -o encoded.png stegasoo encode photo.png -f secret.pdf -o encoded.png
""" """
from PIL import Image from PIL import Image
if not message and not file_payload: if not message and not file_payload:
raise click.UsageError("Either --message or --file is required") raise click.UsageError("Either --message or --file is required")
# Parse compression algorithm # Parse compression algorithm
algo_map = { algo_map = {
'zlib': CompressionAlgorithm.ZLIB, 'zlib': CompressionAlgorithm.ZLIB,
@@ -97,11 +93,11 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
'none': CompressionAlgorithm.NONE, 'none': CompressionAlgorithm.NONE,
} }
compression_algo = algo_map[algorithm] if compress else CompressionAlgorithm.NONE compression_algo = algo_map[algorithm] if compress else CompressionAlgorithm.NONE
if algorithm == 'lz4' and not HAS_LZ4: if algorithm == 'lz4' and not HAS_LZ4:
click.echo("Warning: LZ4 not available, falling back to zlib", err=True) click.echo("Warning: LZ4 not available, falling back to zlib", err=True)
compression_algo = CompressionAlgorithm.ZLIB compression_algo = CompressionAlgorithm.ZLIB
# Calculate payload size # Calculate payload size
if file_payload: if file_payload:
payload_size = Path(file_payload).stat().st_size payload_size = Path(file_payload).stat().st_size
@@ -109,12 +105,12 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
else: else:
payload_size = len(message.encode('utf-8')) payload_size = len(message.encode('utf-8'))
payload_type = "text" payload_type = "text"
# Get image capacity # Get image capacity
with Image.open(image) as img: with Image.open(image) as img:
width, height = img.size width, height = img.size
capacity_bytes = (width * height * 3 // 8) - 69 # v3.2.0: corrected overhead capacity_bytes = (width * height * 3 // 8) - 69 # v3.2.0: corrected overhead
if dry_run: if dry_run:
result = { result = {
"image": image, "image": image,
@@ -126,7 +122,7 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
"usage_percent": round(payload_size / capacity_bytes * 100, 1), "usage_percent": round(payload_size / capacity_bytes * 100, 1),
"fits": payload_size < capacity_bytes, "fits": payload_size < capacity_bytes,
} }
if ctx.obj.get('json'): if ctx.obj.get('json'):
click.echo(json.dumps(result, indent=2)) click.echo(json.dumps(result, indent=2))
else: else:
@@ -137,11 +133,11 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
click.echo(f"Usage: {result['usage_percent']}%") click.echo(f"Usage: {result['usage_percent']}%")
click.echo(f"Status: {'✓ Fits' if result['fits'] else '✗ Too large'}") click.echo(f"Status: {'✓ Fits' if result['fits'] else '✗ Too large'}")
return return
# Actual encoding would happen here # Actual encoding would happen here
# For now, show what would be done # For now, show what would be done
output = output or f"{Path(image).stem}_encoded.png" output = output or f"{Path(image).stem}_encoded.png"
if ctx.obj.get('json'): if ctx.obj.get('json'):
click.echo(json.dumps({ click.echo(json.dumps({
"status": "success", "status": "success",
@@ -159,17 +155,17 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
@click.argument('image', type=click.Path(exists=True)) @click.argument('image', type=click.Path(exists=True))
@click.option('--passphrase', prompt=True, hide_input=True, help='Passphrase') @click.option('--passphrase', prompt=True, hide_input=True, help='Passphrase')
@click.option('--pin', prompt=True, hide_input=True, help='PIN code') @click.option('--pin', prompt=True, hide_input=True, help='PIN code')
@click.option('-o', '--output', type=click.Path(), @click.option('-o', '--output', type=click.Path(),
help='Output path for file payloads') help='Output path for file payloads')
@click.pass_context @click.pass_context
def decode(ctx, image, passphrase, pin, output): def decode(ctx, image, passphrase, pin, output):
""" """
Decode a message or file from an image. Decode a message or file from an image.
Examples: Examples:
stegasoo decode encoded.png --passphrase --pin stegasoo decode encoded.png --passphrase --pin
stegasoo decode encoded.png -o ./extracted/ stegasoo decode encoded.png -o ./extracted/
""" """
# Actual decoding would happen here # Actual decoding would happen here
@@ -179,7 +175,7 @@ def decode(ctx, image, passphrase, pin, output):
"payload_type": "text", "payload_type": "text",
"message": "[Decoded message would appear here]", "message": "[Decoded message would appear here]",
} }
if ctx.obj.get('json'): if ctx.obj.get('json'):
click.echo(json.dumps(result, indent=2)) click.echo(json.dumps(result, indent=2))
else: else:
@@ -222,27 +218,27 @@ def batch_encode(ctx, images, message, file_payload, output_dir, suffix,
passphrase, pin, compress, algorithm, recursive, jobs, verbose): passphrase, pin, compress, algorithm, recursive, jobs, verbose):
""" """
Encode message into multiple images. Encode message into multiple images.
Examples: Examples:
stegasoo batch encode *.png -m "Secret" --passphrase --pin stegasoo batch encode *.png -m "Secret" --passphrase --pin
stegasoo batch encode ./photos/ -r -o ./encoded/ stegasoo batch encode ./photos/ -r -o ./encoded/
""" """
if not message and not file_payload: if not message and not file_payload:
raise click.UsageError("Either --message or --file is required") raise click.UsageError("Either --message or --file is required")
processor = BatchProcessor(max_workers=jobs) processor = BatchProcessor(max_workers=jobs)
# Progress callback # Progress callback
def progress(current, total, item): def progress(current, total, item):
if not ctx.obj.get('json'): if not ctx.obj.get('json'):
status = "" if item.status.value == "success" else "" status = "" if item.status.value == "success" else ""
click.echo(f"[{current}/{total}] {status} {item.input_path.name}") click.echo(f"[{current}/{total}] {status} {item.input_path.name}")
# v3.2.0: Use 'passphrase' key instead of 'phrase' # v3.2.0: Use 'passphrase' key instead of 'phrase'
credentials = {"passphrase": passphrase, "pin": pin} credentials = {"passphrase": passphrase, "pin": pin}
result = processor.batch_encode( result = processor.batch_encode(
images=list(images), images=list(images),
message=message, message=message,
@@ -254,7 +250,7 @@ def batch_encode(ctx, images, message, file_payload, output_dir, suffix,
recursive=recursive, recursive=recursive,
progress_callback=progress if not ctx.obj.get('json') else None, progress_callback=progress if not ctx.obj.get('json') else None,
) )
if ctx.obj.get('json'): if ctx.obj.get('json'):
click.echo(result.to_json()) click.echo(result.to_json())
else: else:
@@ -275,24 +271,24 @@ def batch_encode(ctx, images, message, file_payload, output_dir, suffix,
def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verbose): def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verbose):
""" """
Decode messages from multiple images. Decode messages from multiple images.
Examples: Examples:
stegasoo batch decode encoded*.png --passphrase --pin stegasoo batch decode encoded*.png --passphrase --pin
stegasoo batch decode ./encoded/ -r -o ./extracted/ stegasoo batch decode ./encoded/ -r -o ./extracted/
""" """
processor = BatchProcessor(max_workers=jobs) processor = BatchProcessor(max_workers=jobs)
# Progress callback # Progress callback
def progress(current, total, item): def progress(current, total, item):
if not ctx.obj.get('json'): if not ctx.obj.get('json'):
status = "" if item.status.value == "success" else "" status = "" if item.status.value == "success" else ""
click.echo(f"[{current}/{total}] {status} {item.input_path.name}") click.echo(f"[{current}/{total}] {status} {item.input_path.name}")
# v3.2.0: Use 'passphrase' key instead of 'phrase' # v3.2.0: Use 'passphrase' key instead of 'phrase'
credentials = {"passphrase": passphrase, "pin": pin} credentials = {"passphrase": passphrase, "pin": pin}
result = processor.batch_decode( result = processor.batch_decode(
images=list(images), images=list(images),
output_dir=Path(output_dir) if output_dir else None, output_dir=Path(output_dir) if output_dir else None,
@@ -300,7 +296,7 @@ def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verb
recursive=recursive, recursive=recursive,
progress_callback=progress if not ctx.obj.get('json') else None, progress_callback=progress if not ctx.obj.get('json') else None,
) )
if ctx.obj.get('json'): if ctx.obj.get('json'):
click.echo(result.to_json()) click.echo(result.to_json())
else: else:
@@ -315,21 +311,21 @@ def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verb
def batch_check(ctx, images, recursive): def batch_check(ctx, images, recursive):
""" """
Check capacity of multiple images. Check capacity of multiple images.
Examples: Examples:
stegasoo batch check *.png stegasoo batch check *.png
stegasoo batch check ./photos/ -r stegasoo batch check ./photos/ -r
""" """
results = batch_capacity_check(list(images), recursive) results = batch_capacity_check(list(images), recursive)
if ctx.obj.get('json'): if ctx.obj.get('json'):
click.echo(json.dumps(results, indent=2)) click.echo(json.dumps(results, indent=2))
else: else:
click.echo(f"{'Image':<40} {'Size':<12} {'Capacity':<12} {'Status'}") click.echo(f"{'Image':<40} {'Size':<12} {'Capacity':<12} {'Status'}")
click.echo("" * 80) click.echo("" * 80)
for item in results: for item in results:
if 'error' in item: if 'error' in item:
click.echo(f"{Path(item['path']).name:<40} {'ERROR':<12} {'':<12} {item['error']}") click.echo(f"{Path(item['path']).name:<40} {'ERROR':<12} {'':<12} {item['error']}")
@@ -337,10 +333,10 @@ def batch_check(ctx, images, recursive):
name = Path(item['path']).name name = Path(item['path']).name
if len(name) > 38: if len(name) > 38:
name = name[:35] + "..." name = name[:35] + "..."
status = "" if item['valid'] else "" status = "" if item['valid'] else ""
warnings = ", ".join(item.get('warnings', [])) warnings = ", ".join(item.get('warnings', []))
click.echo( click.echo(
f"{name:<40} " f"{name:<40} "
f"{item['dimensions']:<12} " f"{item['dimensions']:<12} "
@@ -354,7 +350,7 @@ def batch_check(ctx, images, recursive):
# ============================================================================= # =============================================================================
@cli.command() @cli.command()
@click.option('--words', default=DEFAULT_PASSPHRASE_WORDS, @click.option('--words', default=DEFAULT_PASSPHRASE_WORDS,
help=f'Number of words in passphrase (default: {DEFAULT_PASSPHRASE_WORDS})') help=f'Number of words in passphrase (default: {DEFAULT_PASSPHRASE_WORDS})')
@click.option('--pin-length', default=DEFAULT_PIN_LENGTH, @click.option('--pin-length', default=DEFAULT_PIN_LENGTH,
help=f'PIN length (default: {DEFAULT_PIN_LENGTH})') help=f'PIN length (default: {DEFAULT_PIN_LENGTH})')
@@ -362,21 +358,21 @@ def batch_check(ctx, images, recursive):
def generate(ctx, words, pin_length): def generate(ctx, words, pin_length):
""" """
Generate random credentials (passphrase + PIN). Generate random credentials (passphrase + PIN).
Examples: Examples:
stegasoo generate stegasoo generate
stegasoo generate --words 6 --pin-length 8 stegasoo generate --words 6 --pin-length 8
""" """
import secrets import secrets
# Generate PIN # Generate PIN
pin = ''.join(str(secrets.randbelow(10)) for _ in range(pin_length)) pin = ''.join(str(secrets.randbelow(10)) for _ in range(pin_length))
# Ensure PIN doesn't start with 0 # Ensure PIN doesn't start with 0
if pin[0] == '0': if pin[0] == '0':
pin = str(secrets.randbelow(9) + 1) + pin[1:] pin = str(secrets.randbelow(9) + 1) + pin[1:]
# Generate passphrase (would use BIP-39 wordlist) # Generate passphrase (would use BIP-39 wordlist)
# Placeholder - actual implementation uses constants.get_wordlist() # Placeholder - actual implementation uses constants.get_wordlist()
try: try:
@@ -388,16 +384,16 @@ def generate(ctx, words, pin_length):
sample_words = ['alpha', 'bravo', 'charlie', 'delta', 'echo', 'foxtrot', sample_words = ['alpha', 'bravo', 'charlie', 'delta', 'echo', 'foxtrot',
'golf', 'hotel', 'india', 'juliet', 'kilo', 'lima'] 'golf', 'hotel', 'india', 'juliet', 'kilo', 'lima']
phrase_words = [secrets.choice(sample_words) for _ in range(words)] phrase_words = [secrets.choice(sample_words) for _ in range(words)]
passphrase = ' '.join(phrase_words) passphrase = ' '.join(phrase_words)
result = { result = {
"passphrase": passphrase, "passphrase": passphrase,
"pin": pin, "pin": pin,
"passphrase_words": words, "passphrase_words": words,
"pin_length": pin_length, "pin_length": pin_length,
} }
if ctx.obj.get('json'): if ctx.obj.get('json'):
click.echo(json.dumps(result, indent=2)) click.echo(json.dumps(result, indent=2))
else: else:
@@ -421,17 +417,17 @@ def info(ctx):
"max_file_payload_bytes": MAX_FILE_PAYLOAD_SIZE, "max_file_payload_bytes": MAX_FILE_PAYLOAD_SIZE,
}, },
} }
if ctx.obj.get('json'): if ctx.obj.get('json'):
click.echo(json.dumps(info_data, indent=2)) click.echo(json.dumps(info_data, indent=2))
else: else:
click.echo(f"Stegasoo v{__version__}") click.echo(f"Stegasoo v{__version__}")
click.echo(f"\nCompression algorithms:") click.echo("\nCompression algorithms:")
for algo in get_available_algorithms(): for algo in get_available_algorithms():
click.echo(f"{algorithm_name(algo)}") click.echo(f"{algorithm_name(algo)}")
if not HAS_LZ4: if not HAS_LZ4:
click.echo(" (install 'lz4' for LZ4 support)") click.echo(" (install 'lz4' for LZ4 support)")
click.echo(f"\nLimits:") click.echo("\nLimits:")
click.echo(f" • Max message: {MAX_MESSAGE_SIZE:,} bytes") click.echo(f" • Max message: {MAX_MESSAGE_SIZE:,} bytes")
click.echo(f" • Max file payload: {MAX_FILE_PAYLOAD_SIZE:,} bytes") click.echo(f" • Max file payload: {MAX_FILE_PAYLOAD_SIZE:,} bytes")

View File

@@ -5,10 +5,9 @@ Provides transparent compression/decompression for payloads before encryption.
Supports multiple algorithms with automatic detection on decompression. Supports multiple algorithms with automatic detection on decompression.
""" """
import zlib
import struct import struct
import zlib
from enum import IntEnum from enum import IntEnum
from typing import Optional
# Optional LZ4 support (faster, slightly worse ratio) # Optional LZ4 support (faster, slightly worse ratio)
try: try:
@@ -43,26 +42,26 @@ class CompressionError(Exception):
def compress(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm.ZLIB) -> bytes: def compress(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm.ZLIB) -> bytes:
""" """
Compress data with specified algorithm. Compress data with specified algorithm.
Format: MAGIC (4) + ALGORITHM (1) + ORIGINAL_SIZE (4) + COMPRESSED_DATA Format: MAGIC (4) + ALGORITHM (1) + ORIGINAL_SIZE (4) + COMPRESSED_DATA
Args: Args:
data: Raw bytes to compress data: Raw bytes to compress
algorithm: Compression algorithm to use algorithm: Compression algorithm to use
Returns: Returns:
Compressed data with header, or original data if compression didn't help Compressed data with header, or original data if compression didn't help
""" """
if len(data) < MIN_COMPRESS_SIZE: if len(data) < MIN_COMPRESS_SIZE:
# Too small to benefit from compression # Too small to benefit from compression
return _wrap_uncompressed(data) return _wrap_uncompressed(data)
if algorithm == CompressionAlgorithm.NONE: if algorithm == CompressionAlgorithm.NONE:
return _wrap_uncompressed(data) return _wrap_uncompressed(data)
elif algorithm == CompressionAlgorithm.ZLIB: elif algorithm == CompressionAlgorithm.ZLIB:
compressed = zlib.compress(data, level=ZLIB_LEVEL) compressed = zlib.compress(data, level=ZLIB_LEVEL)
elif algorithm == CompressionAlgorithm.LZ4: elif algorithm == CompressionAlgorithm.LZ4:
if not HAS_LZ4: if not HAS_LZ4:
# Fall back to zlib if LZ4 not available # Fall back to zlib if LZ4 not available
@@ -72,11 +71,11 @@ def compress(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm
compressed = lz4.frame.compress(data) compressed = lz4.frame.compress(data)
else: else:
raise CompressionError(f"Unknown compression algorithm: {algorithm}") raise CompressionError(f"Unknown compression algorithm: {algorithm}")
# Only use compression if it actually reduced size # Only use compression if it actually reduced size
if len(compressed) >= len(data): if len(compressed) >= len(data):
return _wrap_uncompressed(data) return _wrap_uncompressed(data)
# Build header: MAGIC + algorithm + original_size + compressed_data # Build header: MAGIC + algorithm + original_size + compressed_data
header = COMPRESSION_MAGIC + struct.pack('<BI', algorithm, len(data)) header = COMPRESSION_MAGIC + struct.pack('<BI', algorithm, len(data))
return header + compressed return header + compressed
@@ -85,10 +84,10 @@ def compress(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm
def decompress(data: bytes) -> bytes: def decompress(data: bytes) -> bytes:
""" """
Decompress data, auto-detecting algorithm from header. Decompress data, auto-detecting algorithm from header.
Args: Args:
data: Potentially compressed data data: Potentially compressed data
Returns: Returns:
Decompressed data (or original if not compressed) Decompressed data (or original if not compressed)
""" """
@@ -96,24 +95,24 @@ def decompress(data: bytes) -> bytes:
if not data.startswith(COMPRESSION_MAGIC): if not data.startswith(COMPRESSION_MAGIC):
# Not compressed by us, return as-is # Not compressed by us, return as-is
return data return data
if len(data) < 9: # MAGIC(4) + ALGO(1) + SIZE(4) if len(data) < 9: # MAGIC(4) + ALGO(1) + SIZE(4)
raise CompressionError("Truncated compression header") raise CompressionError("Truncated compression header")
# Parse header # Parse header
algorithm = CompressionAlgorithm(data[4]) algorithm = CompressionAlgorithm(data[4])
original_size = struct.unpack('<I', data[5:9])[0] original_size = struct.unpack('<I', data[5:9])[0]
compressed_data = data[9:] compressed_data = data[9:]
if algorithm == CompressionAlgorithm.NONE: if algorithm == CompressionAlgorithm.NONE:
result = compressed_data result = compressed_data
elif algorithm == CompressionAlgorithm.ZLIB: elif algorithm == CompressionAlgorithm.ZLIB:
try: try:
result = zlib.decompress(compressed_data) result = zlib.decompress(compressed_data)
except zlib.error as e: except zlib.error as e:
raise CompressionError(f"Zlib decompression failed: {e}") raise CompressionError(f"Zlib decompression failed: {e}")
elif algorithm == CompressionAlgorithm.LZ4: elif algorithm == CompressionAlgorithm.LZ4:
if not HAS_LZ4: if not HAS_LZ4:
raise CompressionError("LZ4 compression used but lz4 package not installed") raise CompressionError("LZ4 compression used but lz4 package not installed")
@@ -123,13 +122,13 @@ def decompress(data: bytes) -> bytes:
raise CompressionError(f"LZ4 decompression failed: {e}") raise CompressionError(f"LZ4 decompression failed: {e}")
else: else:
raise CompressionError(f"Unknown compression algorithm: {algorithm}") raise CompressionError(f"Unknown compression algorithm: {algorithm}")
# Verify size # Verify size
if len(result) != original_size: if len(result) != original_size:
raise CompressionError( raise CompressionError(
f"Size mismatch: expected {original_size}, got {len(result)}" f"Size mismatch: expected {original_size}, got {len(result)}"
) )
return result return result
@@ -142,7 +141,7 @@ def _wrap_uncompressed(data: bytes) -> bytes:
def get_compression_ratio(original: bytes, compressed: bytes) -> float: def get_compression_ratio(original: bytes, compressed: bytes) -> float:
""" """
Calculate compression ratio. Calculate compression ratio.
Returns: Returns:
Ratio where < 1.0 means compression helped, > 1.0 means it expanded Ratio where < 1.0 means compression helped, > 1.0 means it expanded
""" """
@@ -155,36 +154,36 @@ def estimate_compressed_size(data: bytes, algorithm: CompressionAlgorithm = Comp
""" """
Estimate compressed size without full compression. Estimate compressed size without full compression.
Uses sampling for large data. Uses sampling for large data.
Args: Args:
data: Data to estimate data: Data to estimate
algorithm: Algorithm to estimate for algorithm: Algorithm to estimate for
Returns: Returns:
Estimated compressed size in bytes Estimated compressed size in bytes
""" """
if len(data) < MIN_COMPRESS_SIZE: if len(data) < MIN_COMPRESS_SIZE:
return len(data) + 9 # Header overhead return len(data) + 9 # Header overhead
# For small data, just compress it # For small data, just compress it
if len(data) < 10000: if len(data) < 10000:
compressed = compress(data, algorithm) compressed = compress(data, algorithm)
return len(compressed) return len(compressed)
# For large data, sample and extrapolate # For large data, sample and extrapolate
sample_size = 8192 sample_size = 8192
sample = data[:sample_size] sample = data[:sample_size]
if algorithm == CompressionAlgorithm.ZLIB: if algorithm == CompressionAlgorithm.ZLIB:
compressed_sample = zlib.compress(sample, level=ZLIB_LEVEL) compressed_sample = zlib.compress(sample, level=ZLIB_LEVEL)
elif algorithm == CompressionAlgorithm.LZ4 and HAS_LZ4: elif algorithm == CompressionAlgorithm.LZ4 and HAS_LZ4:
compressed_sample = lz4.frame.compress(sample) compressed_sample = lz4.frame.compress(sample)
else: else:
compressed_sample = zlib.compress(sample, level=ZLIB_LEVEL) compressed_sample = zlib.compress(sample, level=ZLIB_LEVEL)
ratio = len(compressed_sample) / len(sample) ratio = len(compressed_sample) / len(sample)
estimated = int(len(data) * ratio) + 9 # Add header estimated = int(len(data) * ratio) + 9 # Add header
return estimated return estimated

View File

@@ -14,7 +14,6 @@ BREAKING CHANGES in v3.2.0:
- Renamed day_phrase → passphrase throughout codebase - Renamed day_phrase → passphrase throughout codebase
""" """
import os
from pathlib import Path from pathlib import Path
# ============================================================================ # ============================================================================
@@ -89,7 +88,7 @@ RECOMMENDED_PASSPHRASE_WORDS = 4 # Best practice guideline
# Legacy aliases for backward compatibility during transition # Legacy aliases for backward compatibility during transition
MIN_PHRASE_WORDS = MIN_PASSPHRASE_WORDS MIN_PHRASE_WORDS = MIN_PASSPHRASE_WORDS
MAX_PHRASE_WORDS = MAX_PASSPHRASE_WORDS MAX_PHRASE_WORDS = MAX_PASSPHRASE_WORDS
DEFAULT_PHRASE_WORDS = DEFAULT_PASSPHRASE_WORDS DEFAULT_PHRASE_WORDS = DEFAULT_PASSPHRASE_WORDS
# RSA configuration # RSA configuration
@@ -180,11 +179,11 @@ def get_data_dir() -> Path:
Path.cwd().parent / 'data', # One level up from cwd Path.cwd().parent / 'data', # One level up from cwd
Path.cwd().parent.parent / 'data', # Two levels up from cwd Path.cwd().parent.parent / 'data', # Two levels up from cwd
] ]
for path in candidates: for path in candidates:
if path.exists(): if path.exists():
return path return path
# Default to first candidate # Default to first candidate
return candidates[0] return candidates[0]
@@ -192,14 +191,14 @@ def get_data_dir() -> Path:
def get_bip39_words() -> list[str]: def get_bip39_words() -> list[str]:
"""Load BIP-39 wordlist.""" """Load BIP-39 wordlist."""
wordlist_path = get_data_dir() / 'bip39-words.txt' wordlist_path = get_data_dir() / 'bip39-words.txt'
if not wordlist_path.exists(): if not wordlist_path.exists():
raise FileNotFoundError( raise FileNotFoundError(
f"BIP-39 wordlist not found at {wordlist_path}. " f"BIP-39 wordlist not found at {wordlist_path}. "
"Please ensure bip39-words.txt is in the data directory." "Please ensure bip39-words.txt is in the data directory."
) )
with open(wordlist_path, 'r') as f: with open(wordlist_path) as f:
return [line.strip() for line in f if line.strip()] return [line.strip() for line in f if line.strip()]
@@ -240,18 +239,18 @@ DCT_BYTES_PER_PIXEL = 0.125 # Approximate for DCT mode (varies by implementatio
def detect_stego_mode(encrypted_data: bytes) -> str: def detect_stego_mode(encrypted_data: bytes) -> str:
""" """
Detect embedding mode from encrypted payload header. Detect embedding mode from encrypted payload header.
Args: Args:
encrypted_data: First few bytes of extracted payload encrypted_data: First few bytes of extracted payload
Returns: Returns:
'lsb' or 'dct' or 'unknown' 'lsb' or 'dct' or 'unknown'
""" """
if len(encrypted_data) < 4: if len(encrypted_data) < 4:
return 'unknown' return 'unknown'
header = encrypted_data[:4] header = encrypted_data[:4]
if header == b'\x89ST3': if header == b'\x89ST3':
return EMBED_MODE_LSB return EMBED_MODE_LSB
elif header == b'\x89DCT': elif header == b'\x89DCT':

View File

@@ -15,38 +15,40 @@ BREAKING CHANGES in v3.2.0:
- Renamed day_phrase → passphrase (no daily rotation needed) - Renamed day_phrase → passphrase (no daily rotation needed)
""" """
import io
import hashlib import hashlib
import io
import secrets import secrets
import struct import struct
import json
from typing import Optional, Union
from PIL import Image
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from PIL import Image
from .constants import ( from .constants import (
MAGIC_HEADER, FORMAT_VERSION, ARGON2_MEMORY_COST,
SALT_SIZE, IV_SIZE, TAG_SIZE, ARGON2_PARALLELISM,
ARGON2_TIME_COST, ARGON2_MEMORY_COST, ARGON2_PARALLELISM, ARGON2_TIME_COST,
PBKDF2_ITERATIONS, FORMAT_VERSION,
PAYLOAD_TEXT, PAYLOAD_FILE, IV_SIZE,
MAGIC_HEADER,
MAX_FILENAME_LENGTH, MAX_FILENAME_LENGTH,
PAYLOAD_FILE,
PAYLOAD_TEXT,
PBKDF2_ITERATIONS,
SALT_SIZE,
TAG_SIZE,
) )
from .models import FilePayload, DecodeResult from .exceptions import DecryptionError, EncryptionError, InvalidHeaderError, KeyDerivationError
from .exceptions import ( from .models import DecodeResult, FilePayload
EncryptionError, DecryptionError, KeyDerivationError, InvalidHeaderError
)
# Check for Argon2 availability # Check for Argon2 availability
try: try:
from argon2.low_level import hash_secret_raw, Type from argon2.low_level import Type, hash_secret_raw
HAS_ARGON2 = True HAS_ARGON2 = True
except ImportError: except ImportError:
HAS_ARGON2 = False HAS_ARGON2 = False
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
# ============================================================================= # =============================================================================
@@ -57,28 +59,28 @@ except ImportError:
CHANNEL_KEY_AUTO = "auto" CHANNEL_KEY_AUTO = "auto"
def _resolve_channel_key(channel_key: Optional[Union[str, bool]]) -> Optional[bytes]: def _resolve_channel_key(channel_key: str | bool | None) -> bytes | None:
""" """
Resolve channel key parameter to actual key hash. Resolve channel key parameter to actual key hash.
Args: Args:
channel_key: Channel key parameter with these behaviors: channel_key: Channel key parameter with these behaviors:
- None or "auto": Use server's configured key (from env/config) - None or "auto": Use server's configured key (from env/config)
- str (valid key): Use this specific key - str (valid key): Use this specific key
- "" or False: Explicitly use NO channel key (public mode) - "" or False: Explicitly use NO channel key (public mode)
Returns: Returns:
32-byte channel key hash, or None for public mode 32-byte channel key hash, or None for public mode
""" """
# Explicit public mode # Explicit public mode
if channel_key == "" or channel_key is False: if channel_key == "" or channel_key is False:
return None return None
# Auto-detect from environment/config # Auto-detect from environment/config
if channel_key is None or channel_key == CHANNEL_KEY_AUTO: if channel_key is None or channel_key == CHANNEL_KEY_AUTO:
from .channel import get_channel_key_hash from .channel import get_channel_key_hash
return get_channel_key_hash() return get_channel_key_hash()
# Explicit key provided - validate and hash it # Explicit key provided - validate and hash it
if isinstance(channel_key, str): if isinstance(channel_key, str):
from .channel import format_channel_key, validate_channel_key from .channel import format_channel_key, validate_channel_key
@@ -86,7 +88,7 @@ def _resolve_channel_key(channel_key: Optional[Union[str, bool]]) -> Optional[by
raise ValueError(f"Invalid channel key format: {channel_key}") raise ValueError(f"Invalid channel key format: {channel_key}")
formatted = format_channel_key(channel_key) formatted = format_channel_key(channel_key)
return hashlib.sha256(formatted.encode('utf-8')).digest() return hashlib.sha256(formatted.encode('utf-8')).digest()
raise ValueError(f"Invalid channel_key type: {type(channel_key)}") raise ValueError(f"Invalid channel_key type: {type(channel_key)}")
@@ -97,19 +99,19 @@ def _resolve_channel_key(channel_key: Optional[Union[str, bool]]) -> Optional[by
def hash_photo(image_data: bytes) -> bytes: def hash_photo(image_data: bytes) -> bytes:
""" """
Compute deterministic hash of photo pixel content. Compute deterministic hash of photo pixel content.
This normalizes the image to RGB and hashes the raw pixel data, This normalizes the image to RGB and hashes the raw pixel data,
making it resistant to metadata changes. making it resistant to metadata changes.
Args: Args:
image_data: Raw image file bytes image_data: Raw image file bytes
Returns: Returns:
32-byte SHA-256 hash 32-byte SHA-256 hash
""" """
img: Image.Image = Image.open(io.BytesIO(image_data)).convert('RGB') img: Image.Image = Image.open(io.BytesIO(image_data)).convert('RGB')
pixels = img.tobytes() pixels = img.tobytes()
# Double-hash with prefix for additional mixing # Double-hash with prefix for additional mixing
h = hashlib.sha256(pixels).digest() h = hashlib.sha256(pixels).digest()
h = hashlib.sha256(h + pixels[:1024]).digest() h = hashlib.sha256(h + pixels[:1024]).digest()
@@ -121,12 +123,12 @@ def derive_hybrid_key(
passphrase: str, passphrase: str,
salt: bytes, salt: bytes,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> bytes: ) -> bytes:
""" """
Derive encryption key from multiple factors. Derive encryption key from multiple factors.
Combines: Combines:
- Photo hash (something you have) - Photo hash (something you have)
- Passphrase (something you know) - Passphrase (something you know)
@@ -134,9 +136,9 @@ def derive_hybrid_key(
- RSA key (something you have) - RSA key (something you have)
- Channel key (deployment/group binding) - Channel key (deployment/group binding)
- Salt (random per message) - Salt (random per message)
Uses Argon2id if available, falls back to PBKDF2. Uses Argon2id if available, falls back to PBKDF2.
Args: Args:
photo_data: Reference photo bytes photo_data: Reference photo bytes
passphrase: Shared passphrase (recommend 4+ words) passphrase: Shared passphrase (recommend 4+ words)
@@ -147,19 +149,19 @@ def derive_hybrid_key(
- None or "auto": Use configured key - None or "auto": Use configured key
- str: Use this specific key - str: Use this specific key
- "" or False: No channel key (public mode) - "" or False: No channel key (public mode)
Returns: Returns:
32-byte derived key 32-byte derived key
Raises: Raises:
KeyDerivationError: If key derivation fails KeyDerivationError: If key derivation fails
""" """
try: try:
photo_hash = hash_photo(photo_data) photo_hash = hash_photo(photo_data)
# Resolve channel key # Resolve channel key
channel_hash = _resolve_channel_key(channel_key) channel_hash = _resolve_channel_key(channel_key)
# Build key material # Build key material
key_material = ( key_material = (
photo_hash + photo_hash +
@@ -167,15 +169,15 @@ def derive_hybrid_key(
pin.encode() + pin.encode() +
salt salt
) )
# Add RSA key hash if provided # Add RSA key hash if provided
if rsa_key_data: if rsa_key_data:
key_material += hashlib.sha256(rsa_key_data).digest() key_material += hashlib.sha256(rsa_key_data).digest()
# Add channel key hash if configured (v4.0.0) # Add channel key hash if configured (v4.0.0)
if channel_hash: if channel_hash:
key_material += channel_hash key_material += channel_hash
if HAS_ARGON2: if HAS_ARGON2:
key = hash_secret_raw( key = hash_secret_raw(
secret=key_material, secret=key_material,
@@ -195,9 +197,9 @@ def derive_hybrid_key(
backend=default_backend() backend=default_backend()
) )
key = kdf.derive(key_material) key = kdf.derive(key_material)
return key return key
except Exception as e: except Exception as e:
raise KeyDerivationError(f"Failed to derive key: {e}") from e raise KeyDerivationError(f"Failed to derive key: {e}") from e
@@ -206,61 +208,61 @@ def derive_pixel_key(
photo_data: bytes, photo_data: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> bytes: ) -> bytes:
""" """
Derive key for pseudo-random pixel selection. Derive key for pseudo-random pixel selection.
This key determines which pixels are used for embedding, This key determines which pixels are used for embedding,
making the message location unpredictable without the correct inputs. making the message location unpredictable without the correct inputs.
Args: Args:
photo_data: Reference photo bytes photo_data: Reference photo bytes
passphrase: Shared passphrase passphrase: Shared passphrase
pin: Optional static PIN pin: Optional static PIN
rsa_key_data: Optional RSA key bytes rsa_key_data: Optional RSA key bytes
channel_key: Channel key parameter (see derive_hybrid_key) channel_key: Channel key parameter (see derive_hybrid_key)
Returns: Returns:
32-byte key for pixel selection 32-byte key for pixel selection
""" """
photo_hash = hash_photo(photo_data) photo_hash = hash_photo(photo_data)
# Resolve channel key # Resolve channel key
channel_hash = _resolve_channel_key(channel_key) channel_hash = _resolve_channel_key(channel_key)
material = ( material = (
photo_hash + photo_hash +
passphrase.lower().encode() + passphrase.lower().encode() +
pin.encode() pin.encode()
) )
if rsa_key_data: if rsa_key_data:
material += hashlib.sha256(rsa_key_data).digest() material += hashlib.sha256(rsa_key_data).digest()
# Add channel key hash if configured (v4.0.0) # Add channel key hash if configured (v4.0.0)
if channel_hash: if channel_hash:
material += channel_hash material += channel_hash
return hashlib.sha256(material + b"pixel_selection").digest() return hashlib.sha256(material + b"pixel_selection").digest()
def _pack_payload( def _pack_payload(
content: Union[str, bytes, FilePayload], content: str | bytes | FilePayload,
) -> tuple[bytes, int]: ) -> tuple[bytes, int]:
""" """
Pack payload with type marker and metadata. Pack payload with type marker and metadata.
Format for text: Format for text:
[type:1][data] [type:1][data]
Format for file: Format for file:
[type:1][filename_len:2][filename][mime_len:2][mime][data] [type:1][filename_len:2][filename][mime_len:2][mime][data]
Args: Args:
content: Text string, raw bytes, or FilePayload content: Text string, raw bytes, or FilePayload
Returns: Returns:
Tuple of (packed bytes, payload type) Tuple of (packed bytes, payload type)
""" """
@@ -268,12 +270,12 @@ def _pack_payload(
# Text message # Text message
data = content.encode('utf-8') data = content.encode('utf-8')
return bytes([PAYLOAD_TEXT]) + data, PAYLOAD_TEXT return bytes([PAYLOAD_TEXT]) + data, PAYLOAD_TEXT
elif isinstance(content, FilePayload): elif isinstance(content, FilePayload):
# File with metadata # File with metadata
filename = content.filename[:MAX_FILENAME_LENGTH].encode('utf-8') filename = content.filename[:MAX_FILENAME_LENGTH].encode('utf-8')
mime = (content.mime_type or '')[:100].encode('utf-8') mime = (content.mime_type or '')[:100].encode('utf-8')
packed = ( packed = (
bytes([PAYLOAD_FILE]) + bytes([PAYLOAD_FILE]) +
struct.pack('>H', len(filename)) + struct.pack('>H', len(filename)) +
@@ -283,7 +285,7 @@ def _pack_payload(
content.data content.data
) )
return packed, PAYLOAD_FILE return packed, PAYLOAD_FILE
else: else:
# Raw bytes - treat as file with no name # Raw bytes - treat as file with no name
packed = ( packed = (
@@ -298,49 +300,49 @@ def _pack_payload(
def _unpack_payload(data: bytes) -> DecodeResult: def _unpack_payload(data: bytes) -> DecodeResult:
""" """
Unpack payload and extract content with metadata. Unpack payload and extract content with metadata.
Args: Args:
data: Packed payload bytes data: Packed payload bytes
Returns: Returns:
DecodeResult with appropriate content DecodeResult with appropriate content
""" """
if len(data) < 1: if len(data) < 1:
raise DecryptionError("Empty payload") raise DecryptionError("Empty payload")
payload_type = data[0] payload_type = data[0]
if payload_type == PAYLOAD_TEXT: if payload_type == PAYLOAD_TEXT:
# Text message # Text message
text = data[1:].decode('utf-8') text = data[1:].decode('utf-8')
return DecodeResult(payload_type='text', message=text) return DecodeResult(payload_type='text', message=text)
elif payload_type == PAYLOAD_FILE: elif payload_type == PAYLOAD_FILE:
# File with metadata # File with metadata
offset = 1 offset = 1
# Read filename # Read filename
filename_len = struct.unpack('>H', data[offset:offset+2])[0] filename_len = struct.unpack('>H', data[offset:offset+2])[0]
offset += 2 offset += 2
filename = data[offset:offset+filename_len].decode('utf-8') if filename_len else None filename = data[offset:offset+filename_len].decode('utf-8') if filename_len else None
offset += filename_len offset += filename_len
# Read mime type # Read mime type
mime_len = struct.unpack('>H', data[offset:offset+2])[0] mime_len = struct.unpack('>H', data[offset:offset+2])[0]
offset += 2 offset += 2
mime_type = data[offset:offset+mime_len].decode('utf-8') if mime_len else None mime_type = data[offset:offset+mime_len].decode('utf-8') if mime_len else None
offset += mime_len offset += mime_len
# Rest is file data # Rest is file data
file_data = data[offset:] file_data = data[offset:]
return DecodeResult( return DecodeResult(
payload_type='file', payload_type='file',
file_data=file_data, file_data=file_data,
filename=filename, filename=filename,
mime_type=mime_type mime_type=mime_type
) )
else: else:
# Unknown type - try to decode as text (backward compatibility) # Unknown type - try to decode as text (backward compatibility)
try: try:
@@ -359,16 +361,16 @@ FLAG_CHANNEL_KEY = 0x01 # Set if encoded with a channel key
def encrypt_message( def encrypt_message(
message: Union[str, bytes, FilePayload], message: str | bytes | FilePayload,
photo_data: bytes, photo_data: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> bytes: ) -> bytes:
""" """
Encrypt message or file using AES-256-GCM with hybrid key derivation. Encrypt message or file using AES-256-GCM with hybrid key derivation.
Message format (v4.0.0 - with channel key support): Message format (v4.0.0 - with channel key support):
- Magic header (4 bytes) - Magic header (4 bytes)
- Version (1 byte) = 5 - Version (1 byte) = 5
@@ -377,7 +379,7 @@ def encrypt_message(
- IV (12 bytes) - IV (12 bytes)
- Auth tag (16 bytes) - Auth tag (16 bytes)
- Ciphertext (variable, padded) - Ciphertext (variable, padded)
Args: Args:
message: Message string, raw bytes, or FilePayload to encrypt message: Message string, raw bytes, or FilePayload to encrypt
photo_data: Reference photo bytes photo_data: Reference photo bytes
@@ -386,12 +388,12 @@ def encrypt_message(
rsa_key_data: Optional RSA key bytes rsa_key_data: Optional RSA key bytes
channel_key: Channel key parameter: channel_key: Channel key parameter:
- None or "auto": Use configured key - None or "auto": Use configured key
- str: Use this specific key - str: Use this specific key
- "" or False: No channel key (public mode) - "" or False: No channel key (public mode)
Returns: Returns:
Encrypted message bytes Encrypted message bytes
Raises: Raises:
EncryptionError: If encryption fails EncryptionError: If encryption fails
""" """
@@ -399,32 +401,32 @@ def encrypt_message(
salt = secrets.token_bytes(SALT_SIZE) salt = secrets.token_bytes(SALT_SIZE)
key = derive_hybrid_key(photo_data, passphrase, salt, pin, rsa_key_data, channel_key) key = derive_hybrid_key(photo_data, passphrase, salt, pin, rsa_key_data, channel_key)
iv = secrets.token_bytes(IV_SIZE) iv = secrets.token_bytes(IV_SIZE)
# Determine flags # Determine flags
flags = 0 flags = 0
channel_hash = _resolve_channel_key(channel_key) channel_hash = _resolve_channel_key(channel_key)
if channel_hash: if channel_hash:
flags |= FLAG_CHANNEL_KEY flags |= FLAG_CHANNEL_KEY
# Pack payload with type marker # Pack payload with type marker
packed_payload, _ = _pack_payload(message) packed_payload, _ = _pack_payload(message)
# Random padding to hide message length # Random padding to hide message length
padding_len = secrets.randbelow(256) + 64 padding_len = secrets.randbelow(256) + 64
padded_len = ((len(packed_payload) + padding_len + 255) // 256) * 256 padded_len = ((len(packed_payload) + padding_len + 255) // 256) * 256
padding_needed = padded_len - len(packed_payload) padding_needed = padded_len - len(packed_payload)
padding = secrets.token_bytes(padding_needed - 4) + struct.pack('>I', len(packed_payload)) padding = secrets.token_bytes(padding_needed - 4) + struct.pack('>I', len(packed_payload))
padded_message = packed_payload + padding padded_message = packed_payload + padding
# Build header for AAD # Build header for AAD
header = MAGIC_HEADER + bytes([FORMAT_VERSION, flags]) header = MAGIC_HEADER + bytes([FORMAT_VERSION, flags])
# Encrypt with AES-256-GCM # Encrypt with AES-256-GCM
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend()) cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
encryptor.authenticate_additional_data(header) encryptor.authenticate_additional_data(header)
ciphertext = encryptor.update(padded_message) + encryptor.finalize() ciphertext = encryptor.update(padded_message) + encryptor.finalize()
# v4.0.0: Header with flags byte # v4.0.0: Header with flags byte
return ( return (
header + header +
@@ -433,34 +435,34 @@ def encrypt_message(
encryptor.tag + encryptor.tag +
ciphertext ciphertext
) )
except Exception as e: except Exception as e:
raise EncryptionError(f"Encryption failed: {e}") from e raise EncryptionError(f"Encryption failed: {e}") from e
def parse_header(encrypted_data: bytes) -> Optional[dict]: def parse_header(encrypted_data: bytes) -> dict | None:
""" """
Parse the header from encrypted data. Parse the header from encrypted data.
v4.0.0: Includes flags byte for channel key indicator. v4.0.0: Includes flags byte for channel key indicator.
Args: Args:
encrypted_data: Raw encrypted bytes encrypted_data: Raw encrypted bytes
Returns: Returns:
Dict with salt, iv, tag, ciphertext, flags or None if invalid Dict with salt, iv, tag, ciphertext, flags or None if invalid
""" """
# Min size: Magic(4) + Version(1) + Flags(1) + Salt(32) + IV(12) + Tag(16) = 66 bytes # Min size: Magic(4) + Version(1) + Flags(1) + Salt(32) + IV(12) + Tag(16) = 66 bytes
if len(encrypted_data) < 66 or encrypted_data[:4] != MAGIC_HEADER: if len(encrypted_data) < 66 or encrypted_data[:4] != MAGIC_HEADER:
return None return None
try: try:
version = encrypted_data[4] version = encrypted_data[4]
if version != FORMAT_VERSION: if version != FORMAT_VERSION:
return None return None
flags = encrypted_data[5] flags = encrypted_data[5]
offset = 6 offset = 6
salt = encrypted_data[offset:offset + SALT_SIZE] salt = encrypted_data[offset:offset + SALT_SIZE]
offset += SALT_SIZE offset += SALT_SIZE
@@ -469,7 +471,7 @@ def parse_header(encrypted_data: bytes) -> Optional[dict]:
tag = encrypted_data[offset:offset + TAG_SIZE] tag = encrypted_data[offset:offset + TAG_SIZE]
offset += TAG_SIZE offset += TAG_SIZE
ciphertext = encrypted_data[offset:] ciphertext = encrypted_data[offset:]
return { return {
'version': version, 'version': version,
'flags': flags, 'flags': flags,
@@ -488,12 +490,12 @@ def decrypt_message(
photo_data: bytes, photo_data: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> DecodeResult: ) -> DecodeResult:
""" """
Decrypt message (v4.0.0 - with channel key support). Decrypt message (v4.0.0 - with channel key support).
Args: Args:
encrypted_data: Encrypted message bytes encrypted_data: Encrypted message bytes
photo_data: Reference photo bytes photo_data: Reference photo bytes
@@ -501,10 +503,10 @@ def decrypt_message(
pin: Optional static PIN pin: Optional static PIN
rsa_key_data: Optional RSA key bytes rsa_key_data: Optional RSA key bytes
channel_key: Channel key parameter (see encrypt_message) channel_key: Channel key parameter (see encrypt_message)
Returns: Returns:
DecodeResult with decrypted content DecodeResult with decrypted content
Raises: Raises:
InvalidHeaderError: If data doesn't have valid Stegasoo header InvalidHeaderError: If data doesn't have valid Stegasoo header
DecryptionError: If decryption fails (wrong credentials) DecryptionError: If decryption fails (wrong credentials)
@@ -512,20 +514,20 @@ def decrypt_message(
header = parse_header(encrypted_data) header = parse_header(encrypted_data)
if not header: if not header:
raise InvalidHeaderError("Invalid or missing Stegasoo header") raise InvalidHeaderError("Invalid or missing Stegasoo header")
# Check for channel key mismatch and provide helpful error # Check for channel key mismatch and provide helpful error
channel_hash = _resolve_channel_key(channel_key) channel_hash = _resolve_channel_key(channel_key)
has_configured_key = channel_hash is not None has_configured_key = channel_hash is not None
message_has_key = header['has_channel_key'] message_has_key = header['has_channel_key']
try: try:
key = derive_hybrid_key( key = derive_hybrid_key(
photo_data, passphrase, header['salt'], pin, rsa_key_data, channel_key photo_data, passphrase, header['salt'], pin, rsa_key_data, channel_key
) )
# Reconstruct header for AAD verification # Reconstruct header for AAD verification
aad_header = MAGIC_HEADER + bytes([FORMAT_VERSION, header['flags']]) aad_header = MAGIC_HEADER + bytes([FORMAT_VERSION, header['flags']])
cipher = Cipher( cipher = Cipher(
algorithms.AES(key), algorithms.AES(key),
modes.GCM(header['iv'], header['tag']), modes.GCM(header['iv'], header['tag']),
@@ -533,15 +535,15 @@ def decrypt_message(
) )
decryptor = cipher.decryptor() decryptor = cipher.decryptor()
decryptor.authenticate_additional_data(aad_header) decryptor.authenticate_additional_data(aad_header)
padded_plaintext = decryptor.update(header['ciphertext']) + decryptor.finalize() padded_plaintext = decryptor.update(header['ciphertext']) + decryptor.finalize()
original_length = struct.unpack('>I', padded_plaintext[-4:])[0] original_length = struct.unpack('>I', padded_plaintext[-4:])[0]
payload_data = padded_plaintext[:original_length] payload_data = padded_plaintext[:original_length]
result = _unpack_payload(payload_data) result = _unpack_payload(payload_data)
return result return result
except Exception as e: except Exception as e:
# Provide more helpful error message for channel key issues # Provide more helpful error message for channel key issues
if message_has_key and not has_configured_key: if message_has_key and not has_configured_key:
@@ -566,14 +568,14 @@ def decrypt_message_text(
photo_data: bytes, photo_data: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> str: ) -> str:
""" """
Decrypt message and return as text string. Decrypt message and return as text string.
For backward compatibility - returns text content or raises error for files. For backward compatibility - returns text content or raises error for files.
Args: Args:
encrypted_data: Encrypted message bytes encrypted_data: Encrypted message bytes
photo_data: Reference photo bytes photo_data: Reference photo bytes
@@ -581,15 +583,15 @@ def decrypt_message_text(
pin: Optional static PIN pin: Optional static PIN
rsa_key_data: Optional RSA key bytes rsa_key_data: Optional RSA key bytes
channel_key: Channel key parameter channel_key: Channel key parameter
Returns: Returns:
Decrypted message string Decrypted message string
Raises: Raises:
DecryptionError: If decryption fails or content is a file DecryptionError: If decryption fails or content is a file
""" """
result = decrypt_message(encrypted_data, photo_data, passphrase, pin, rsa_key_data, channel_key) result = decrypt_message(encrypted_data, photo_data, passphrase, pin, rsa_key_data, channel_key)
if result.is_file: if result.is_file:
if result.file_data: if result.file_data:
# Try to decode as text # Try to decode as text
@@ -600,7 +602,7 @@ def decrypt_message_text(
f"Content is a binary file ({result.filename or 'unnamed'}), not text" f"Content is a binary file ({result.filename or 'unnamed'}), not text"
) )
return "" return ""
return result.message or "" return result.message or ""
@@ -613,10 +615,10 @@ def has_argon2() -> bool:
# CHANNEL KEY UTILITIES (exposed for convenience) # CHANNEL KEY UTILITIES (exposed for convenience)
# ============================================================================= # =============================================================================
def get_active_channel_key() -> Optional[str]: def get_active_channel_key() -> str | None:
""" """
Get the currently configured channel key (if any). Get the currently configured channel key (if any).
Returns: Returns:
Formatted channel key string, or None if not configured Formatted channel key string, or None if not configured
""" """
@@ -624,7 +626,7 @@ def get_active_channel_key() -> Optional[str]:
return get_channel_key() return get_channel_key()
def get_channel_fingerprint(key: Optional[str] = None) -> Optional[str]: def get_channel_fingerprint(key: str | None = None) -> str | None:
""" """
Get a display-safe fingerprint of a channel key. Get a display-safe fingerprint of a channel key.

View File

@@ -14,12 +14,11 @@ v3.2.0-patch2 Changes:
Requires: scipy (for PNG mode), optionally jpegio (for JPEG mode) Requires: scipy (for PNG mode), optionally jpegio (for JPEG mode)
""" """
import gc
import hashlib
import io import io
import struct import struct
import hashlib
import gc
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple
from enum import Enum from enum import Enum
import numpy as np import numpy as np
@@ -103,7 +102,7 @@ class DCTEmbedStats:
color_mode: str = 'grayscale' color_mode: str = 'grayscale'
@dataclass @dataclass
class DCTCapacityInfo: class DCTCapacityInfo:
width: int width: int
height: int height: int
@@ -147,19 +146,19 @@ def _safe_dct2(block: np.ndarray) -> np.ndarray:
""" """
# Create a brand new array (not a view) # Create a brand new array (not a view)
safe_block = np.array(block, dtype=np.float64, copy=True, order='C') safe_block = np.array(block, dtype=np.float64, copy=True, order='C')
# First DCT on columns (transpose -> DCT rows -> transpose back) # First DCT on columns (transpose -> DCT rows -> transpose back)
temp = np.zeros_like(safe_block, dtype=np.float64, order='C') temp = np.zeros_like(safe_block, dtype=np.float64, order='C')
for i in range(BLOCK_SIZE): for i in range(BLOCK_SIZE):
col = np.array(safe_block[:, i], dtype=np.float64, copy=True) col = np.array(safe_block[:, i], dtype=np.float64, copy=True)
temp[:, i] = dct(col, norm='ortho') temp[:, i] = dct(col, norm='ortho')
# Second DCT on rows # Second DCT on rows
result = np.zeros_like(temp, dtype=np.float64, order='C') result = np.zeros_like(temp, dtype=np.float64, order='C')
for i in range(BLOCK_SIZE): for i in range(BLOCK_SIZE):
row = np.array(temp[i, :], dtype=np.float64, copy=True) row = np.array(temp[i, :], dtype=np.float64, copy=True)
result[i, :] = dct(row, norm='ortho') result[i, :] = dct(row, norm='ortho')
return result return result
@@ -170,19 +169,19 @@ def _safe_idct2(block: np.ndarray) -> np.ndarray:
""" """
# Create a brand new array (not a view) # Create a brand new array (not a view)
safe_block = np.array(block, dtype=np.float64, copy=True, order='C') safe_block = np.array(block, dtype=np.float64, copy=True, order='C')
# First IDCT on rows # First IDCT on rows
temp = np.zeros_like(safe_block, dtype=np.float64, order='C') temp = np.zeros_like(safe_block, dtype=np.float64, order='C')
for i in range(BLOCK_SIZE): for i in range(BLOCK_SIZE):
row = np.array(safe_block[i, :], dtype=np.float64, copy=True) row = np.array(safe_block[i, :], dtype=np.float64, copy=True)
temp[i, :] = idct(row, norm='ortho') temp[i, :] = idct(row, norm='ortho')
# Second IDCT on columns # Second IDCT on columns
result = np.zeros_like(temp, dtype=np.float64, order='C') result = np.zeros_like(temp, dtype=np.float64, order='C')
for i in range(BLOCK_SIZE): for i in range(BLOCK_SIZE):
col = np.array(temp[:, i], dtype=np.float64, copy=True) col = np.array(temp[:, i], dtype=np.float64, copy=True)
result[:, i] = idct(col, norm='ortho') result[:, i] = idct(col, norm='ortho')
return result return result
@@ -200,23 +199,23 @@ def _extract_y_channel(image_data: bytes) -> np.ndarray:
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
if img.mode != 'RGB': if img.mode != 'RGB':
img = img.convert('RGB') img = img.convert('RGB')
rgb = np.array(img, dtype=np.float64, copy=True, order='C') rgb = np.array(img, dtype=np.float64, copy=True, order='C')
Y = 0.299 * rgb[:, :, 0] + 0.587 * rgb[:, :, 1] + 0.114 * rgb[:, :, 2] Y = 0.299 * rgb[:, :, 0] + 0.587 * rgb[:, :, 1] + 0.114 * rgb[:, :, 2]
return np.array(Y, dtype=np.float64, copy=True, order='C') return np.array(Y, dtype=np.float64, copy=True, order='C')
def _pad_to_blocks(image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]: def _pad_to_blocks(image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
h, w = image.shape h, w = image.shape
new_h = ((h + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE new_h = ((h + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
new_w = ((w + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE new_w = ((w + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
if new_h == h and new_w == w: if new_h == h and new_w == w:
return np.array(image, dtype=np.float64, copy=True, order='C'), (h, w) return np.array(image, dtype=np.float64, copy=True, order='C'), (h, w)
padded = np.zeros((new_h, new_w), dtype=np.float64, order='C') padded = np.zeros((new_h, new_w), dtype=np.float64, order='C')
padded[:h, :w] = image padded[:h, :w] = image
# Simple edge replication for padding # Simple edge replication for padding
if new_h > h: if new_h > h:
for i in range(h, new_h): for i in range(h, new_h):
@@ -226,11 +225,11 @@ def _pad_to_blocks(image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
padded[:h, j] = padded[:h, w-1] padded[:h, j] = padded[:h, w-1]
if new_h > h and new_w > w: if new_h > h and new_w > w:
padded[h:, w:] = padded[h-1, w-1] padded[h:, w:] = padded[h-1, w-1]
return padded, (h, w) return padded, (h, w)
def _unpad_image(image: np.ndarray, original_size: Tuple[int, int]) -> np.ndarray: def _unpad_image(image: np.ndarray, original_size: tuple[int, int]) -> np.ndarray:
h, w = original_size h, w = original_size
return np.array(image[:h, :w], dtype=np.float64, copy=True, order='C') return np.array(image[:h, :w], dtype=np.float64, copy=True, order='C')
@@ -263,7 +262,7 @@ def _save_stego_image(image: np.ndarray, output_format: str = OUTPUT_FORMAT_PNG)
img = Image.fromarray(clipped, mode='L') img = Image.fromarray(clipped, mode='L')
buffer = io.BytesIO() buffer = io.BytesIO()
if output_format == OUTPUT_FORMAT_JPEG: if output_format == OUTPUT_FORMAT_JPEG:
img.save(buffer, format='JPEG', quality=JPEG_OUTPUT_QUALITY, img.save(buffer, format='JPEG', quality=JPEG_OUTPUT_QUALITY,
subsampling=0, optimize=True) subsampling=0, optimize=True)
else: else:
img.save(buffer, format='PNG', optimize=True) img.save(buffer, format='PNG', optimize=True)
@@ -282,15 +281,15 @@ def _save_color_image(rgb_array: np.ndarray, output_format: str = OUTPUT_FORMAT_
return buffer.getvalue() return buffer.getvalue()
def _rgb_to_ycbcr(rgb: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: def _rgb_to_ycbcr(rgb: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
R = rgb[:, :, 0].astype(np.float64) R = rgb[:, :, 0].astype(np.float64)
G = rgb[:, :, 1].astype(np.float64) G = rgb[:, :, 1].astype(np.float64)
B = rgb[:, :, 2].astype(np.float64) B = rgb[:, :, 2].astype(np.float64)
Y = np.array(0.299 * R + 0.587 * G + 0.114 * B, dtype=np.float64, copy=True, order='C') Y = np.array(0.299 * R + 0.587 * G + 0.114 * B, dtype=np.float64, copy=True, order='C')
Cb = np.array(128 - 0.168736 * R - 0.331264 * G + 0.5 * B, dtype=np.float64, copy=True, order='C') Cb = np.array(128 - 0.168736 * R - 0.331264 * G + 0.5 * B, dtype=np.float64, copy=True, order='C')
Cr = np.array(128 + 0.5 * R - 0.418688 * G - 0.081312 * B, dtype=np.float64, copy=True, order='C') Cr = np.array(128 + 0.5 * R - 0.418688 * G - 0.081312 * B, dtype=np.float64, copy=True, order='C')
return Y, Cb, Cr return Y, Cb, Cr
@@ -298,7 +297,7 @@ def _ycbcr_to_rgb(Y: np.ndarray, Cb: np.ndarray, Cr: np.ndarray) -> np.ndarray:
R = Y + 1.402 * (Cr - 128) R = Y + 1.402 * (Cr - 128)
G = Y - 0.344136 * (Cb - 128) - 0.714136 * (Cr - 128) G = Y - 0.344136 * (Cb - 128) - 0.714136 * (Cr - 128)
B = Y + 1.772 * (Cb - 128) B = Y + 1.772 * (Cb - 128)
rgb = np.zeros((Y.shape[0], Y.shape[1], 3), dtype=np.float64, order='C') rgb = np.zeros((Y.shape[0], Y.shape[1], 3), dtype=np.float64, order='C')
rgb[:, :, 0] = R rgb[:, :, 0] = R
rgb[:, :, 1] = G rgb[:, :, 1] = G
@@ -310,20 +309,20 @@ def _create_header(data_length: int, flags: int = 0) -> bytes:
return struct.pack('>4sBBI', DCT_MAGIC, 1, flags, data_length) return struct.pack('>4sBBI', DCT_MAGIC, 1, flags, data_length)
def _parse_header(header_bits: list) -> Tuple[int, int, int]: def _parse_header(header_bits: list) -> tuple[int, int, int]:
if len(header_bits) < HEADER_SIZE * 8: if len(header_bits) < HEADER_SIZE * 8:
raise ValueError("Insufficient header data") raise ValueError("Insufficient header data")
header_bytes = bytes([ header_bytes = bytes([
sum(header_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8)) sum(header_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8))
for i in range(HEADER_SIZE) for i in range(HEADER_SIZE)
]) ])
magic, version, flags, length = struct.unpack('>4sBBI', header_bytes) magic, version, flags, length = struct.unpack('>4sBBI', header_bytes)
if magic != DCT_MAGIC: if magic != DCT_MAGIC:
raise ValueError("Invalid DCT stego magic bytes") raise ValueError("Invalid DCT stego magic bytes")
return version, flags, length return version, flags, length
@@ -332,8 +331,8 @@ def _parse_header(header_bits: list) -> Tuple[int, int, int]:
# ============================================================================ # ============================================================================
def _jpegio_bytes_to_file(data: bytes, suffix: str = '.jpg') -> str: def _jpegio_bytes_to_file(data: bytes, suffix: str = '.jpg') -> str:
import tempfile
import os import os
import tempfile
fd, path = tempfile.mkstemp(suffix=suffix) fd, path = tempfile.mkstemp(suffix=suffix)
try: try:
os.write(fd, data) os.write(fd, data)
@@ -366,7 +365,7 @@ def _jpegio_create_header(data_length: int, flags: int = 0) -> bytes:
return struct.pack('>4sBBI', JPEGIO_MAGIC, 1, flags, data_length) return struct.pack('>4sBBI', JPEGIO_MAGIC, 1, flags, data_length)
def _jpegio_parse_header(header_bytes: bytes) -> Tuple[int, int, int]: def _jpegio_parse_header(header_bytes: bytes) -> tuple[int, int, int]:
if len(header_bytes) < HEADER_SIZE: if len(header_bytes) < HEADER_SIZE:
raise ValueError("Insufficient header data") raise ValueError("Insufficient header data")
magic, version, flags, length = struct.unpack('>4sBBI', header_bytes[:HEADER_SIZE]) magic, version, flags, length = struct.unpack('>4sBBI', header_bytes[:HEADER_SIZE])
@@ -382,21 +381,21 @@ def _jpegio_parse_header(header_bytes: bytes) -> Tuple[int, int, int]:
def calculate_dct_capacity(image_data: bytes) -> DCTCapacityInfo: def calculate_dct_capacity(image_data: bytes) -> DCTCapacityInfo:
"""Calculate DCT embedding capacity of an image.""" """Calculate DCT embedding capacity of an image."""
_check_scipy() _check_scipy()
# Just get dimensions, don't process anything # Just get dimensions, don't process anything
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
width, height = img.size width, height = img.size
img.close() # Explicitly close img.close() # Explicitly close
blocks_x = width // BLOCK_SIZE blocks_x = width // BLOCK_SIZE
blocks_y = height // BLOCK_SIZE blocks_y = height // BLOCK_SIZE
total_blocks = blocks_x * blocks_y total_blocks = blocks_x * blocks_y
bits_per_block = len(DEFAULT_EMBED_POSITIONS) bits_per_block = len(DEFAULT_EMBED_POSITIONS)
total_bits = total_blocks * bits_per_block total_bits = total_blocks * bits_per_block
total_bytes = total_bits // 8 total_bytes = total_bits // 8
usable_bytes = max(0, total_bytes - HEADER_SIZE) usable_bytes = max(0, total_bytes - HEADER_SIZE)
return DCTCapacityInfo( return DCTCapacityInfo(
width=width, width=width,
height=height, height=height,
@@ -420,13 +419,13 @@ def estimate_capacity_comparison(image_data: bytes) -> dict:
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
width, height = img.size width, height = img.size
img.close() img.close()
pixels = width * height pixels = width * height
lsb_bytes = (pixels * 3) // 8 lsb_bytes = (pixels * 3) // 8
blocks = (width // 8) * (height // 8) blocks = (width // 8) * (height // 8)
dct_bytes = (blocks * 16) // 8 - HEADER_SIZE dct_bytes = (blocks * 16) // 8 - HEADER_SIZE
return { return {
'width': width, 'width': width,
'height': height, 'height': height,
@@ -455,17 +454,17 @@ def embed_in_dct(
seed: bytes, seed: bytes,
output_format: str = OUTPUT_FORMAT_PNG, output_format: str = OUTPUT_FORMAT_PNG,
color_mode: str = 'color', color_mode: str = 'color',
) -> Tuple[bytes, DCTEmbedStats]: ) -> tuple[bytes, DCTEmbedStats]:
"""Embed data using DCT coefficient modification.""" """Embed data using DCT coefficient modification."""
if output_format not in (OUTPUT_FORMAT_PNG, OUTPUT_FORMAT_JPEG): if output_format not in (OUTPUT_FORMAT_PNG, OUTPUT_FORMAT_JPEG):
raise ValueError(f"Invalid output format: {output_format}") raise ValueError(f"Invalid output format: {output_format}")
if color_mode not in ('color', 'grayscale'): if color_mode not in ('color', 'grayscale'):
color_mode = 'color' color_mode = 'color'
if output_format == OUTPUT_FORMAT_JPEG and HAS_JPEGIO: if output_format == OUTPUT_FORMAT_JPEG and HAS_JPEGIO:
return _embed_jpegio(data, carrier_image, seed, color_mode) return _embed_jpegio(data, carrier_image, seed, color_mode)
_check_scipy() _check_scipy()
return _embed_scipy_dct_safe(data, carrier_image, seed, output_format, color_mode) return _embed_scipy_dct_safe(data, carrier_image, seed, output_format, color_mode)
@@ -476,27 +475,27 @@ def _embed_scipy_dct_safe(
seed: bytes, seed: bytes,
output_format: str, output_format: str,
color_mode: str = 'color', color_mode: str = 'color',
) -> Tuple[bytes, DCTEmbedStats]: ) -> tuple[bytes, DCTEmbedStats]:
""" """
Embed using scipy DCT with safe memory handling. Embed using scipy DCT with safe memory handling.
Uses row-by-row 1D DCT operations instead of 2D arrays to avoid Uses row-by-row 1D DCT operations instead of 2D arrays to avoid
scipy memory corruption issues with large images. scipy memory corruption issues with large images.
""" """
capacity_info = calculate_dct_capacity(carrier_image) capacity_info = calculate_dct_capacity(carrier_image)
if len(data) > capacity_info.usable_capacity_bytes: if len(data) > capacity_info.usable_capacity_bytes:
raise ValueError( raise ValueError(
f"Data too large ({len(data)} bytes) for carrier " f"Data too large ({len(data)} bytes) for carrier "
f"(capacity: {capacity_info.usable_capacity_bytes} bytes)" f"(capacity: {capacity_info.usable_capacity_bytes} bytes)"
) )
# Load image # Load image
img = Image.open(io.BytesIO(carrier_image)) img = Image.open(io.BytesIO(carrier_image))
width, height = img.size width, height = img.size
flags = FLAG_COLOR_MODE if color_mode == 'color' else 0 flags = FLAG_COLOR_MODE if color_mode == 'color' else 0
# Prepare payload bits # Prepare payload bits
header = _create_header(len(data), flags) header = _create_header(len(data), flags)
payload = header + data payload = header + data
@@ -504,41 +503,41 @@ def _embed_scipy_dct_safe(
for byte in payload: for byte in payload:
for i in range(7, -1, -1): for i in range(7, -1, -1):
bits.append((byte >> i) & 1) bits.append((byte >> i) & 1)
# Generate block order # Generate block order
num_blocks = capacity_info.total_blocks num_blocks = capacity_info.total_blocks
block_order = _generate_block_order(num_blocks, seed) block_order = _generate_block_order(num_blocks, seed)
blocks_x = width // BLOCK_SIZE blocks_x = width // BLOCK_SIZE
if color_mode == 'color' and img.mode in ('RGB', 'RGBA'): if color_mode == 'color' and img.mode in ('RGB', 'RGBA'):
if img.mode == 'RGBA': if img.mode == 'RGBA':
img = img.convert('RGB') img = img.convert('RGB')
# Process color image # Process color image
rgb = np.array(img, dtype=np.float64, copy=True, order='C') rgb = np.array(img, dtype=np.float64, copy=True, order='C')
img.close() img.close()
Y, Cb, Cr = _rgb_to_ycbcr(rgb) Y, Cb, Cr = _rgb_to_ycbcr(rgb)
del rgb del rgb
gc.collect() gc.collect()
Y_padded, original_size = _pad_to_blocks(Y) Y_padded, original_size = _pad_to_blocks(Y)
del Y del Y
gc.collect() gc.collect()
# Embed in Y channel # Embed in Y channel
Y_embedded = _embed_in_channel_safe(Y_padded, bits, block_order, blocks_x) Y_embedded = _embed_in_channel_safe(Y_padded, bits, block_order, blocks_x)
del Y_padded del Y_padded
gc.collect() gc.collect()
Y_result = _unpad_image(Y_embedded, original_size) Y_result = _unpad_image(Y_embedded, original_size)
del Y_embedded del Y_embedded
gc.collect() gc.collect()
result_rgb = _ycbcr_to_rgb(Y_result, Cb, Cr) result_rgb = _ycbcr_to_rgb(Y_result, Cb, Cr)
del Y_result, Cb, Cr del Y_result, Cb, Cr
gc.collect() gc.collect()
stego_bytes = _save_color_image(result_rgb, output_format) stego_bytes = _save_color_image(result_rgb, output_format)
del result_rgb del result_rgb
gc.collect() gc.collect()
@@ -546,23 +545,23 @@ def _embed_scipy_dct_safe(
# Grayscale mode # Grayscale mode
image = _to_grayscale(carrier_image) image = _to_grayscale(carrier_image)
img.close() img.close()
padded, original_size = _pad_to_blocks(image) padded, original_size = _pad_to_blocks(image)
del image del image
gc.collect() gc.collect()
embedded = _embed_in_channel_safe(padded, bits, block_order, blocks_x) embedded = _embed_in_channel_safe(padded, bits, block_order, blocks_x)
del padded del padded
gc.collect() gc.collect()
result = _unpad_image(embedded, original_size) result = _unpad_image(embedded, original_size)
del embedded del embedded
gc.collect() gc.collect()
stego_bytes = _save_stego_image(result, output_format) stego_bytes = _save_stego_image(result, output_format)
del result del result
gc.collect() gc.collect()
stats = DCTEmbedStats( stats = DCTEmbedStats(
blocks_used=(len(bits) + len(DEFAULT_EMBED_POSITIONS) - 1) // len(DEFAULT_EMBED_POSITIONS), blocks_used=(len(bits) + len(DEFAULT_EMBED_POSITIONS) - 1) // len(DEFAULT_EMBED_POSITIONS),
blocks_available=capacity_info.total_blocks, blocks_available=capacity_info.total_blocks,
@@ -575,7 +574,7 @@ def _embed_scipy_dct_safe(
jpeg_native=False, jpeg_native=False,
color_mode=color_mode, color_mode=color_mode,
) )
return stego_bytes, stats return stego_bytes, stats
@@ -587,78 +586,78 @@ def _embed_in_channel_safe(
) -> np.ndarray: ) -> np.ndarray:
""" """
Embed bits in channel using safe DCT operations. Embed bits in channel using safe DCT operations.
Processes one block at a time with fresh array allocations. Processes one block at a time with fresh array allocations.
""" """
h, w = channel.shape h, w = channel.shape
# Create result with explicit new memory # Create result with explicit new memory
result = np.array(channel, dtype=np.float64, copy=True, order='C') result = np.array(channel, dtype=np.float64, copy=True, order='C')
bit_idx = 0 bit_idx = 0
for block_num in block_order: for block_num in block_order:
if bit_idx >= len(bits): if bit_idx >= len(bits):
break break
by = (block_num // blocks_x) * BLOCK_SIZE by = (block_num // blocks_x) * BLOCK_SIZE
bx = (block_num % blocks_x) * BLOCK_SIZE bx = (block_num % blocks_x) * BLOCK_SIZE
# Extract block - create brand new array # Extract block - create brand new array
block = np.array( block = np.array(
result[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE], result[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE],
dtype=np.float64, copy=True, order='C' dtype=np.float64, copy=True, order='C'
) )
# Apply safe DCT (row-by-row) # Apply safe DCT (row-by-row)
dct_block = _safe_dct2(block) dct_block = _safe_dct2(block)
# Embed bits # Embed bits
for pos in DEFAULT_EMBED_POSITIONS: for pos in DEFAULT_EMBED_POSITIONS:
if bit_idx >= len(bits): if bit_idx >= len(bits):
break break
dct_block[pos[0], pos[1]] = _embed_bit_in_coeff( dct_block[pos[0], pos[1]] = _embed_bit_in_coeff(
float(dct_block[pos[0], pos[1]]), float(dct_block[pos[0], pos[1]]),
bits[bit_idx] bits[bit_idx]
) )
bit_idx += 1 bit_idx += 1
# Apply safe inverse DCT # Apply safe inverse DCT
modified_block = _safe_idct2(dct_block) modified_block = _safe_idct2(dct_block)
# Copy back # Copy back
result[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE] = modified_block result[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE] = modified_block
# Clean up this iteration # Clean up this iteration
del block, dct_block, modified_block del block, dct_block, modified_block
# Force garbage collection # Force garbage collection
gc.collect() gc.collect()
return result return result
def _normalize_jpeg_for_jpegio(image_data: bytes) -> bytes: def _normalize_jpeg_for_jpegio(image_data: bytes) -> bytes:
""" """
Normalize a JPEG image to ensure jpegio can process it safely. Normalize a JPEG image to ensure jpegio can process it safely.
JPEGs saved with quality=100 have quantization tables with all values = 1, JPEGs saved with quality=100 have quantization tables with all values = 1,
which causes jpegio to crash due to huge coefficient magnitudes. which causes jpegio to crash due to huge coefficient magnitudes.
This function detects such images and re-saves them at a safe quality level. This function detects such images and re-saves them at a safe quality level.
Args: Args:
image_data: Raw JPEG bytes image_data: Raw JPEG bytes
Returns: Returns:
Normalized JPEG bytes (may be unchanged if already safe) Normalized JPEG bytes (may be unchanged if already safe)
""" """
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
# Only process JPEGs # Only process JPEGs
if img.format != 'JPEG': if img.format != 'JPEG':
img.close() img.close()
return image_data return image_data
# Check quantization tables # Check quantization tables
needs_normalization = False needs_normalization = False
if hasattr(img, 'quantization') and img.quantization: if hasattr(img, 'quantization') and img.quantization:
@@ -667,19 +666,19 @@ def _normalize_jpeg_for_jpegio(image_data: bytes) -> bytes:
if max(table) <= JPEGIO_MAX_QUANT_VALUE_THRESHOLD: if max(table) <= JPEGIO_MAX_QUANT_VALUE_THRESHOLD:
needs_normalization = True needs_normalization = True
break break
if not needs_normalization: if not needs_normalization:
img.close() img.close()
return image_data return image_data
# Re-save at safe quality level # Re-save at safe quality level
if img.mode != 'RGB': if img.mode != 'RGB':
img = img.convert('RGB') img = img.convert('RGB')
buffer = io.BytesIO() buffer = io.BytesIO()
img.save(buffer, format='JPEG', quality=JPEGIO_NORMALIZE_QUALITY, subsampling=0) img.save(buffer, format='JPEG', quality=JPEGIO_NORMALIZE_QUALITY, subsampling=0)
img.close() img.close()
return buffer.getvalue() return buffer.getvalue()
@@ -688,17 +687,17 @@ def _embed_jpegio(
carrier_image: bytes, carrier_image: bytes,
seed: bytes, seed: bytes,
color_mode: str = 'color', color_mode: str = 'color',
) -> Tuple[bytes, DCTEmbedStats]: ) -> tuple[bytes, DCTEmbedStats]:
"""Embed using jpegio for proper JPEG coefficient modification.""" """Embed using jpegio for proper JPEG coefficient modification."""
import tempfile
import os import os
import tempfile
# Normalize JPEG to avoid crashes with quality=100 images # Normalize JPEG to avoid crashes with quality=100 images
carrier_image = _normalize_jpeg_for_jpegio(carrier_image) carrier_image = _normalize_jpeg_for_jpegio(carrier_image)
img = Image.open(io.BytesIO(carrier_image)) img = Image.open(io.BytesIO(carrier_image))
width, height = img.size width, height = img.size
if img.format != 'JPEG': if img.format != 'JPEG':
buffer = io.BytesIO() buffer = io.BytesIO()
if img.mode != 'RGB': if img.mode != 'RGB':
@@ -706,54 +705,54 @@ def _embed_jpegio(
img.save(buffer, format='JPEG', quality=95, subsampling=0) img.save(buffer, format='JPEG', quality=95, subsampling=0)
carrier_image = buffer.getvalue() carrier_image = buffer.getvalue()
img.close() img.close()
input_path = _jpegio_bytes_to_file(carrier_image, suffix='.jpg') input_path = _jpegio_bytes_to_file(carrier_image, suffix='.jpg')
output_path = tempfile.mktemp(suffix='.jpg') output_path = tempfile.mktemp(suffix='.jpg')
flags = FLAG_COLOR_MODE if color_mode == 'color' else 0 flags = FLAG_COLOR_MODE if color_mode == 'color' else 0
try: try:
jpeg = jio.read(input_path) jpeg = jio.read(input_path)
coef_array = jpeg.coef_arrays[JPEGIO_EMBED_CHANNEL] coef_array = jpeg.coef_arrays[JPEGIO_EMBED_CHANNEL]
all_positions = _jpegio_get_usable_positions(coef_array) all_positions = _jpegio_get_usable_positions(coef_array)
order = _jpegio_generate_order(len(all_positions), seed) order = _jpegio_generate_order(len(all_positions), seed)
header = _jpegio_create_header(len(data), flags) header = _jpegio_create_header(len(data), flags)
payload = header + data payload = header + data
bits = [] bits = []
for byte in payload: for byte in payload:
for i in range(7, -1, -1): for i in range(7, -1, -1):
bits.append((byte >> i) & 1) bits.append((byte >> i) & 1)
if len(bits) > len(all_positions): if len(bits) > len(all_positions):
raise ValueError( raise ValueError(
f"Payload too large: {len(bits)} bits, " f"Payload too large: {len(bits)} bits, "
f"only {len(all_positions)} usable coefficients" f"only {len(all_positions)} usable coefficients"
) )
coefs_used = 0 coefs_used = 0
for bit_idx, pos_idx in enumerate(order): for bit_idx, pos_idx in enumerate(order):
if bit_idx >= len(bits): if bit_idx >= len(bits):
break break
row, col = all_positions[pos_idx] row, col = all_positions[pos_idx]
coef = coef_array[row, col] coef = coef_array[row, col]
if (coef & 1) != bits[bit_idx]: if (coef & 1) != bits[bit_idx]:
if coef > 0: if coef > 0:
coef_array[row, col] = coef - 1 if (coef & 1) else coef + 1 coef_array[row, col] = coef - 1 if (coef & 1) else coef + 1
else: else:
coef_array[row, col] = coef + 1 if (coef & 1) else coef - 1 coef_array[row, col] = coef + 1 if (coef & 1) else coef - 1
coefs_used += 1 coefs_used += 1
jio.write(jpeg, output_path) jio.write(jpeg, output_path)
with open(output_path, 'rb') as f: with open(output_path, 'rb') as f:
stego_bytes = f.read() stego_bytes = f.read()
stats = DCTEmbedStats( stats = DCTEmbedStats(
blocks_used=coefs_used // 63, blocks_used=coefs_used // 63,
blocks_available=len(all_positions) // 63, blocks_available=len(all_positions) // 63,
@@ -766,9 +765,9 @@ def _embed_jpegio(
jpeg_native=True, jpeg_native=True,
color_mode=color_mode, color_mode=color_mode,
) )
return stego_bytes, stats return stego_bytes, stats
finally: finally:
for path in [input_path, output_path]: for path in [input_path, output_path]:
try: try:
@@ -782,13 +781,13 @@ def extract_from_dct(stego_image: bytes, seed: bytes) -> bytes:
img = Image.open(io.BytesIO(stego_image)) img = Image.open(io.BytesIO(stego_image))
fmt = img.format fmt = img.format
img.close() img.close()
if fmt == 'JPEG' and HAS_JPEGIO: if fmt == 'JPEG' and HAS_JPEGIO:
try: try:
return _extract_jpegio(stego_image, seed) return _extract_jpegio(stego_image, seed)
except ValueError: except ValueError:
pass pass
_check_scipy() _check_scipy()
return _extract_scipy_dct_safe(stego_image, seed) return _extract_scipy_dct_safe(stego_image, seed)
@@ -798,41 +797,41 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
img = Image.open(io.BytesIO(stego_image)) img = Image.open(io.BytesIO(stego_image))
width, height = img.size width, height = img.size
mode = img.mode mode = img.mode
if mode in ('RGB', 'RGBA'): if mode in ('RGB', 'RGBA'):
channel = _extract_y_channel(stego_image) channel = _extract_y_channel(stego_image)
else: else:
channel = _to_grayscale(stego_image) channel = _to_grayscale(stego_image)
img.close() img.close()
padded, _ = _pad_to_blocks(channel) padded, _ = _pad_to_blocks(channel)
del channel del channel
gc.collect() gc.collect()
h, w = padded.shape h, w = padded.shape
blocks_x = w // BLOCK_SIZE blocks_x = w // BLOCK_SIZE
num_blocks = (h // BLOCK_SIZE) * blocks_x num_blocks = (h // BLOCK_SIZE) * blocks_x
block_order = _generate_block_order(num_blocks, seed) block_order = _generate_block_order(num_blocks, seed)
all_bits = [] all_bits = []
for block_num in block_order: for block_num in block_order:
by = (block_num // blocks_x) * BLOCK_SIZE by = (block_num // blocks_x) * BLOCK_SIZE
bx = (block_num % blocks_x) * BLOCK_SIZE bx = (block_num % blocks_x) * BLOCK_SIZE
block = np.array( block = np.array(
padded[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE], padded[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE],
dtype=np.float64, copy=True, order='C' dtype=np.float64, copy=True, order='C'
) )
dct_block = _safe_dct2(block) dct_block = _safe_dct2(block)
for pos in DEFAULT_EMBED_POSITIONS: for pos in DEFAULT_EMBED_POSITIONS:
bit = _extract_bit_from_coeff(float(dct_block[pos[0], pos[1]])) bit = _extract_bit_from_coeff(float(dct_block[pos[0], pos[1]]))
all_bits.append(bit) all_bits.append(bit)
del block, dct_block del block, dct_block
if len(all_bits) >= HEADER_SIZE * 8: if len(all_bits) >= HEADER_SIZE * 8:
try: try:
_, flags, data_length = _parse_header(all_bits[:HEADER_SIZE * 8]) _, flags, data_length = _parse_header(all_bits[:HEADER_SIZE * 8])
@@ -841,53 +840,53 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
break break
except ValueError: except ValueError:
pass pass
del padded del padded
gc.collect() gc.collect()
_, flags, data_length = _parse_header(all_bits) _, flags, data_length = _parse_header(all_bits)
data_bits = all_bits[HEADER_SIZE * 8:(HEADER_SIZE + data_length) * 8] data_bits = all_bits[HEADER_SIZE * 8:(HEADER_SIZE + data_length) * 8]
data = bytes([ data = bytes([
sum(data_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8)) sum(data_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8))
for i in range(data_length) for i in range(data_length)
]) ])
return data return data
def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes: def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
"""Extract using jpegio for JPEG images.""" """Extract using jpegio for JPEG images."""
import os import os
# Normalize JPEG to avoid crashes with quality=100 images # Normalize JPEG to avoid crashes with quality=100 images
# (shouldn't happen with stego images, but be defensive) # (shouldn't happen with stego images, but be defensive)
stego_image = _normalize_jpeg_for_jpegio(stego_image) stego_image = _normalize_jpeg_for_jpegio(stego_image)
temp_path = _jpegio_bytes_to_file(stego_image, suffix='.jpg') temp_path = _jpegio_bytes_to_file(stego_image, suffix='.jpg')
try: try:
jpeg = jio.read(temp_path) jpeg = jio.read(temp_path)
coef_array = jpeg.coef_arrays[JPEGIO_EMBED_CHANNEL] coef_array = jpeg.coef_arrays[JPEGIO_EMBED_CHANNEL]
all_positions = _jpegio_get_usable_positions(coef_array) all_positions = _jpegio_get_usable_positions(coef_array)
order = _jpegio_generate_order(len(all_positions), seed) order = _jpegio_generate_order(len(all_positions), seed)
header_bits = [] header_bits = []
for pos_idx in order[:HEADER_SIZE * 8]: for pos_idx in order[:HEADER_SIZE * 8]:
row, col = all_positions[pos_idx] row, col = all_positions[pos_idx]
coef = coef_array[row, col] coef = coef_array[row, col]
header_bits.append(coef & 1) header_bits.append(coef & 1)
header_bytes = bytes([ header_bytes = bytes([
sum(header_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8)) sum(header_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8))
for i in range(HEADER_SIZE) for i in range(HEADER_SIZE)
]) ])
_, flags, data_length = _jpegio_parse_header(header_bytes) _, flags, data_length = _jpegio_parse_header(header_bytes)
total_bits_needed = (HEADER_SIZE + data_length) * 8 total_bits_needed = (HEADER_SIZE + data_length) * 8
all_bits = [] all_bits = []
for bit_idx, pos_idx in enumerate(order): for bit_idx, pos_idx in enumerate(order):
if bit_idx >= total_bits_needed: if bit_idx >= total_bits_needed:
@@ -895,16 +894,16 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
row, col = all_positions[pos_idx] row, col = all_positions[pos_idx]
coef = coef_array[row, col] coef = coef_array[row, col]
all_bits.append(coef & 1) all_bits.append(coef & 1)
data_bits = all_bits[HEADER_SIZE * 8:] data_bits = all_bits[HEADER_SIZE * 8:]
data = bytes([ data = bytes([
sum(data_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8)) sum(data_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8))
for i in range(data_length) for i in range(data_length)
]) ])
return data return data
finally: finally:
try: try:
os.unlink(temp_path) os.unlink(temp_path)

View File

@@ -5,12 +5,13 @@ Debugging, logging, and performance monitoring tools.
Can be disabled for production use. Can be disabled for production use.
""" """
import sys
import time import time
import traceback import traceback
from collections.abc import Callable
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from typing import Callable, Any, Optional, Dict, Union from typing import Any
import sys
# Global debug configuration # Global debug configuration
DEBUG_ENABLED = False # Set to True to enable debug output DEBUG_ENABLED = False # Set to True to enable debug output
@@ -47,10 +48,10 @@ def debug_data(data: bytes, label: str = "Data", max_bytes: int = 32) -> str:
"""Format bytes for debugging.""" """Format bytes for debugging."""
if not DEBUG_ENABLED: if not DEBUG_ENABLED:
return "" return ""
if not data: if not data:
return f"{label}: Empty" return f"{label}: Empty"
if len(data) <= max_bytes: if len(data) <= max_bytes:
return f"{label} ({len(data)} bytes): {data.hex()}" return f"{label} ({len(data)} bytes): {data.hex()}"
else: else:
@@ -71,7 +72,7 @@ def time_function(func: Callable) -> Callable:
def wrapper(*args, **kwargs) -> Any: def wrapper(*args, **kwargs) -> Any:
if not (DEBUG_ENABLED and LOG_PERFORMANCE): if not (DEBUG_ENABLED and LOG_PERFORMANCE):
return func(*args, **kwargs) return func(*args, **kwargs)
start = time.perf_counter() start = time.perf_counter()
try: try:
result = func(*args, **kwargs) result = func(*args, **kwargs)
@@ -79,7 +80,7 @@ def time_function(func: Callable) -> Callable:
finally: finally:
end = time.perf_counter() end = time.perf_counter()
debug_print(f"{func.__name__} took {end - start:.6f}s", "PERF") debug_print(f"{func.__name__} took {end - start:.6f}s", "PERF")
return wrapper return wrapper
@@ -89,14 +90,15 @@ def validate_assertion(condition: bool, message: str) -> None:
raise AssertionError(f"Validation failed: {message}") raise AssertionError(f"Validation failed: {message}")
def memory_usage() -> Dict[str, Union[float, str]]: def memory_usage() -> dict[str, float | str]:
"""Get current memory usage (if psutil is available).""" """Get current memory usage (if psutil is available)."""
try: try:
import psutil
import os import os
import psutil
process = psutil.Process(os.getpid()) process = psutil.Process(os.getpid())
mem_info = process.memory_info() mem_info = process.memory_info()
return { return {
'rss_mb': mem_info.rss / 1024 / 1024, 'rss_mb': mem_info.rss / 1024 / 1024,
'vms_mb': mem_info.vms / 1024 / 1024, 'vms_mb': mem_info.vms / 1024 / 1024,
@@ -110,66 +112,66 @@ def hexdump(data: bytes, offset: int = 0, length: int = 64) -> str:
"""Create hexdump string for debugging binary data.""" """Create hexdump string for debugging binary data."""
if not data: if not data:
return "Empty" return "Empty"
result = [] result = []
data_to_dump = data[:length] data_to_dump = data[:length]
for i in range(0, len(data_to_dump), 16): for i in range(0, len(data_to_dump), 16):
chunk = data_to_dump[i:i+16] chunk = data_to_dump[i:i+16]
hex_str = ' '.join(f'{b:02x}' for b in chunk) hex_str = ' '.join(f'{b:02x}' for b in chunk)
hex_str = hex_str.ljust(47) hex_str = hex_str.ljust(47)
ascii_str = ''.join(chr(b) if 32 <= b < 127 else '.' for b in chunk) ascii_str = ''.join(chr(b) if 32 <= b < 127 else '.' for b in chunk)
result.append(f"{offset + i:08x}: {hex_str} {ascii_str}") result.append(f"{offset + i:08x}: {hex_str} {ascii_str}")
if len(data) > length: if len(data) > length:
result.append(f"... ({len(data) - length} more bytes)") result.append(f"... ({len(data) - length} more bytes)")
return '\n'.join(result) return '\n'.join(result)
class Debug: class Debug:
"""Debugging utility class.""" """Debugging utility class."""
def __init__(self): def __init__(self):
self.enabled = DEBUG_ENABLED self.enabled = DEBUG_ENABLED
def print(self, message: str, level: str = "INFO") -> None: def print(self, message: str, level: str = "INFO") -> None:
"""Print debug message.""" """Print debug message."""
debug_print(message, level) debug_print(message, level)
def data(self, data: bytes, label: str = "Data", max_bytes: int = 32) -> str: def data(self, data: bytes, label: str = "Data", max_bytes: int = 32) -> str:
"""Format bytes for debugging.""" """Format bytes for debugging."""
return debug_data(data, label, max_bytes) return debug_data(data, label, max_bytes)
def exception(self, e: Exception, context: str = "") -> None: def exception(self, e: Exception, context: str = "") -> None:
"""Log exception with context.""" """Log exception with context."""
debug_exception(e, context) debug_exception(e, context)
def time(self, func: Callable) -> Callable: def time(self, func: Callable) -> Callable:
"""Decorator to time function execution.""" """Decorator to time function execution."""
return time_function(func) return time_function(func)
def validate(self, condition: bool, message: str) -> None: def validate(self, condition: bool, message: str) -> None:
"""Runtime validation assertion.""" """Runtime validation assertion."""
validate_assertion(condition, message) validate_assertion(condition, message)
def memory(self) -> Dict[str, Union[float, str]]: def memory(self) -> dict[str, float | str]:
"""Get current memory usage.""" """Get current memory usage."""
return memory_usage() return memory_usage()
def hexdump(self, data: bytes, offset: int = 0, length: int = 64) -> str: def hexdump(self, data: bytes, offset: int = 0, length: int = 64) -> str:
"""Create hexdump string.""" """Create hexdump string."""
return hexdump(data, offset, length) return hexdump(data, offset, length)
def enable(self, enable: bool = True) -> None: def enable(self, enable: bool = True) -> None:
"""Enable or disable debug mode.""" """Enable or disable debug mode."""
enable_debug(enable) enable_debug(enable)
self.enabled = enable self.enabled = enable
def enable_performance(self, enable: bool = True) -> None: def enable_performance(self, enable: bool = True) -> None:
"""Enable or disable performance logging.""" """Enable or disable performance logging."""
enable_performance_logging(enable) enable_performance_logging(enable)
def enable_assertions(self, enable: bool = True) -> None: def enable_assertions(self, enable: bool = True) -> None:
"""Enable or disable validation assertions.""" """Enable or disable validation assertions."""
enable_assertions(enable) enable_assertions(enable)

View File

@@ -8,21 +8,20 @@ Changes in v4.0.0:
- Improved error messages for channel key mismatches - Improved error messages for channel key mismatches
""" """
from typing import Optional, Union
from pathlib import Path from pathlib import Path
from .models import DecodeInput, DecodeResult from .constants import EMBED_MODE_AUTO
from .crypto import decrypt_message from .crypto import decrypt_message
from .debug import debug
from .exceptions import DecryptionError, ExtractionError
from .models import DecodeResult
from .steganography import extract_from_image from .steganography import extract_from_image
from .validation import ( from .validation import (
require_valid_image,
require_security_factors, require_security_factors,
require_valid_image,
require_valid_pin, require_valid_pin,
require_valid_rsa_key, require_valid_rsa_key,
) )
from .constants import EMBED_MODE_AUTO
from .exceptions import ExtractionError, DecryptionError
from .debug import debug
def decode( def decode(
@@ -30,14 +29,14 @@ def decode(
reference_photo: bytes, reference_photo: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO, embed_mode: str = EMBED_MODE_AUTO,
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> DecodeResult: ) -> DecodeResult:
""" """
Decode a message or file from a stego image. Decode a message or file from a stego image.
Args: Args:
stego_image: Stego image bytes stego_image: Stego image bytes
reference_photo: Shared reference photo bytes reference_photo: Shared reference photo bytes
@@ -50,10 +49,10 @@ def decode(
- None or "auto": Use server's configured key - None or "auto": Use server's configured key
- str: Use this specific channel key - str: Use this specific channel key
- "" or False: No channel key (public mode) - "" or False: No channel key (public mode)
Returns: Returns:
DecodeResult with message or file data DecodeResult with message or file data
Example: Example:
>>> result = decode( >>> result = decode(
... stego_image=stego_bytes, ... stego_image=stego_bytes,
@@ -66,7 +65,7 @@ def decode(
... else: ... else:
... with open(result.filename, 'wb') as f: ... with open(result.filename, 'wb') as f:
... f.write(result.file_data) ... f.write(result.file_data)
Example with explicit channel key: Example with explicit channel key:
>>> result = decode( >>> result = decode(
... stego_image=stego_bytes, ... stego_image=stego_bytes,
@@ -79,41 +78,41 @@ def decode(
debug.print(f"decode: passphrase length={len(passphrase.split())} words, " debug.print(f"decode: passphrase length={len(passphrase.split())} words, "
f"mode={embed_mode}, " f"mode={embed_mode}, "
f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}") f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}")
# Validate inputs # Validate inputs
require_valid_image(stego_image, "Stego image") require_valid_image(stego_image, "Stego image")
require_valid_image(reference_photo, "Reference photo") require_valid_image(reference_photo, "Reference photo")
require_security_factors(pin, rsa_key_data) require_security_factors(pin, rsa_key_data)
if pin: if pin:
require_valid_pin(pin) require_valid_pin(pin)
if rsa_key_data: if rsa_key_data:
require_valid_rsa_key(rsa_key_data, rsa_password) require_valid_rsa_key(rsa_key_data, rsa_password)
# Derive pixel/coefficient selection key (with channel key) # Derive pixel/coefficient selection key (with channel key)
from .crypto import derive_pixel_key from .crypto import derive_pixel_key
pixel_key = derive_pixel_key( pixel_key = derive_pixel_key(
reference_photo, passphrase, pin, rsa_key_data, channel_key reference_photo, passphrase, pin, rsa_key_data, channel_key
) )
# Extract encrypted data # Extract encrypted data
encrypted = extract_from_image( encrypted = extract_from_image(
stego_image, stego_image,
pixel_key, pixel_key,
embed_mode=embed_mode, embed_mode=embed_mode,
) )
if not encrypted: if not encrypted:
debug.print("No data extracted from image") debug.print("No data extracted from image")
raise ExtractionError("Could not extract data. Check your credentials and image.") raise ExtractionError("Could not extract data. Check your credentials and image.")
debug.print(f"Extracted {len(encrypted)} bytes from image") debug.print(f"Extracted {len(encrypted)} bytes from image")
# Decrypt (with channel key) # Decrypt (with channel key)
result = decrypt_message( result = decrypt_message(
encrypted, reference_photo, passphrase, pin, rsa_key_data, channel_key encrypted, reference_photo, passphrase, pin, rsa_key_data, channel_key
) )
debug.print(f"Decryption successful: {result.payload_type}") debug.print(f"Decryption successful: {result.payload_type}")
return result return result
@@ -122,16 +121,16 @@ def decode_file(
stego_image: bytes, stego_image: bytes,
reference_photo: bytes, reference_photo: bytes,
passphrase: str, passphrase: str,
output_path: Optional[Path] = None, output_path: Path | None = None,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO, embed_mode: str = EMBED_MODE_AUTO,
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> Path: ) -> Path:
""" """
Decode a file from a stego image and save it. Decode a file from a stego image and save it.
Args: Args:
stego_image: Stego image bytes stego_image: Stego image bytes
reference_photo: Shared reference photo bytes reference_photo: Shared reference photo bytes
@@ -142,10 +141,10 @@ def decode_file(
rsa_password: Optional RSA key password rsa_password: Optional RSA key password
embed_mode: 'auto', 'lsb', or 'dct' embed_mode: 'auto', 'lsb', or 'dct'
channel_key: Channel key parameter (see decode()) channel_key: Channel key parameter (see decode())
Returns: Returns:
Path where file was saved Path where file was saved
Raises: Raises:
DecryptionError: If payload is text, not a file DecryptionError: If payload is text, not a file
""" """
@@ -159,20 +158,20 @@ def decode_file(
embed_mode, embed_mode,
channel_key, channel_key,
) )
if not result.is_file: if not result.is_file:
raise DecryptionError("Payload is a text message, not a file") raise DecryptionError("Payload is a text message, not a file")
if output_path is None: if output_path is None:
output_path = Path(result.filename or "extracted_file") output_path = Path(result.filename or "extracted_file")
else: else:
output_path = Path(output_path) output_path = Path(output_path)
if output_path.is_dir(): if output_path.is_dir():
output_path = output_path / (result.filename or "extracted_file") output_path = output_path / (result.filename or "extracted_file")
# Write file # Write file
output_path.write_bytes(result.file_data or b"") output_path.write_bytes(result.file_data or b"")
debug.print(f"File saved to: {output_path}") debug.print(f"File saved to: {output_path}")
return output_path return output_path
@@ -182,16 +181,16 @@ def decode_text(
reference_photo: bytes, reference_photo: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO, embed_mode: str = EMBED_MODE_AUTO,
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> str: ) -> str:
""" """
Decode a text message from a stego image. Decode a text message from a stego image.
Convenience function that returns just the message string. Convenience function that returns just the message string.
Args: Args:
stego_image: Stego image bytes stego_image: Stego image bytes
reference_photo: Shared reference photo bytes reference_photo: Shared reference photo bytes
@@ -201,10 +200,10 @@ def decode_text(
rsa_password: Optional RSA key password rsa_password: Optional RSA key password
embed_mode: 'auto', 'lsb', or 'dct' embed_mode: 'auto', 'lsb', or 'dct'
channel_key: Channel key parameter (see decode()) channel_key: Channel key parameter (see decode())
Returns: Returns:
Decoded message string Decoded message string
Raises: Raises:
DecryptionError: If payload is a file, not text DecryptionError: If payload is a file, not text
""" """
@@ -218,7 +217,7 @@ def decode_text(
embed_mode, embed_mode,
channel_key, channel_key,
) )
if result.is_file: if result.is_file:
# Try to decode as text # Try to decode as text
if result.file_data: if result.file_data:
@@ -229,5 +228,5 @@ def decode_text(
f"Payload is a binary file ({result.filename or 'unnamed'}), not text" f"Payload is a binary file ({result.filename or 'unnamed'}), not text"
) )
return "" return ""
return result.message or "" return result.message or ""

View File

@@ -7,41 +7,40 @@ Changes in v4.0.0:
- Added channel_key parameter for deployment/group isolation - Added channel_key parameter for deployment/group isolation
""" """
from typing import Optional, Union
from pathlib import Path from pathlib import Path
from .models import EncodeInput, EncodeResult, FilePayload from .constants import EMBED_MODE_LSB
from .crypto import encrypt_message, derive_pixel_key from .crypto import derive_pixel_key, encrypt_message
from .debug import debug
from .models import EncodeResult, FilePayload
from .steganography import embed_in_image from .steganography import embed_in_image
from .utils import generate_filename
from .validation import ( from .validation import (
require_valid_payload,
require_valid_image,
require_security_factors, require_security_factors,
require_valid_image,
require_valid_payload,
require_valid_pin, require_valid_pin,
require_valid_rsa_key, require_valid_rsa_key,
) )
from .utils import generate_filename
from .constants import EMBED_MODE_LSB
from .debug import debug
def encode( def encode(
message: Union[str, bytes, FilePayload], message: str | bytes | FilePayload,
reference_photo: bytes, reference_photo: bytes,
carrier_image: bytes, carrier_image: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
output_format: Optional[str] = None, output_format: str | None = None,
embed_mode: str = EMBED_MODE_LSB, embed_mode: str = EMBED_MODE_LSB,
dct_output_format: str = "png", dct_output_format: str = "png",
dct_color_mode: str = "grayscale", dct_color_mode: str = "grayscale",
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> EncodeResult: ) -> EncodeResult:
""" """
Encode a message or file into an image. Encode a message or file into an image.
Args: Args:
message: Text message, raw bytes, or FilePayload to hide message: Text message, raw bytes, or FilePayload to hide
reference_photo: Shared reference photo bytes reference_photo: Shared reference photo bytes
@@ -58,10 +57,10 @@ def encode(
- None or "auto": Use server's configured key - None or "auto": Use server's configured key
- str: Use this specific channel key - str: Use this specific channel key
- "" or False: No channel key (public mode) - "" or False: No channel key (public mode)
Returns: Returns:
EncodeResult with stego image and metadata EncodeResult with stego image and metadata
Example: Example:
>>> result = encode( >>> result = encode(
... message="Secret message", ... message="Secret message",
@@ -72,7 +71,7 @@ def encode(
... ) ... )
>>> with open('stego.png', 'wb') as f: >>> with open('stego.png', 'wb') as f:
... f.write(result.stego_image) ... f.write(result.stego_image)
Example with explicit channel key: Example with explicit channel key:
>>> result = encode( >>> result = encode(
... message="Secret message", ... message="Secret message",
@@ -86,30 +85,30 @@ def encode(
debug.print(f"encode: passphrase length={len(passphrase.split())} words, " debug.print(f"encode: passphrase length={len(passphrase.split())} words, "
f"pin={'set' if pin else 'none'}, mode={embed_mode}, " f"pin={'set' if pin else 'none'}, mode={embed_mode}, "
f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}") f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}")
# Validate inputs # Validate inputs
require_valid_payload(message) require_valid_payload(message)
require_valid_image(reference_photo, "Reference photo") require_valid_image(reference_photo, "Reference photo")
require_valid_image(carrier_image, "Carrier image") require_valid_image(carrier_image, "Carrier image")
require_security_factors(pin, rsa_key_data) require_security_factors(pin, rsa_key_data)
if pin: if pin:
require_valid_pin(pin) require_valid_pin(pin)
if rsa_key_data: if rsa_key_data:
require_valid_rsa_key(rsa_key_data, rsa_password) require_valid_rsa_key(rsa_key_data, rsa_password)
# Encrypt message (with channel key) # Encrypt message (with channel key)
encrypted = encrypt_message( encrypted = encrypt_message(
message, reference_photo, passphrase, pin, rsa_key_data, channel_key message, reference_photo, passphrase, pin, rsa_key_data, channel_key
) )
debug.print(f"Encrypted payload: {len(encrypted)} bytes") debug.print(f"Encrypted payload: {len(encrypted)} bytes")
# Derive pixel/coefficient selection key (with channel key) # Derive pixel/coefficient selection key (with channel key)
pixel_key = derive_pixel_key( pixel_key = derive_pixel_key(
reference_photo, passphrase, pin, rsa_key_data, channel_key reference_photo, passphrase, pin, rsa_key_data, channel_key
) )
# Embed in image # Embed in image
stego_data, stats, extension = embed_in_image( stego_data, stats, extension = embed_in_image(
encrypted, encrypted,
@@ -120,10 +119,10 @@ def encode(
dct_output_format=dct_output_format, dct_output_format=dct_output_format,
dct_color_mode=dct_color_mode, dct_color_mode=dct_color_mode,
) )
# Generate filename # Generate filename
filename = generate_filename(extension=extension) filename = generate_filename(extension=extension)
# Create result # Create result
if hasattr(stats, 'pixels_modified'): if hasattr(stats, 'pixels_modified'):
# LSB mode stats # LSB mode stats
@@ -148,25 +147,25 @@ def encode(
def encode_file( def encode_file(
filepath: Union[str, Path], filepath: str | Path,
reference_photo: bytes, reference_photo: bytes,
carrier_image: bytes, carrier_image: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
output_format: Optional[str] = None, output_format: str | None = None,
filename_override: Optional[str] = None, filename_override: str | None = None,
embed_mode: str = EMBED_MODE_LSB, embed_mode: str = EMBED_MODE_LSB,
dct_output_format: str = "png", dct_output_format: str = "png",
dct_color_mode: str = "grayscale", dct_color_mode: str = "grayscale",
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> EncodeResult: ) -> EncodeResult:
""" """
Encode a file into an image. Encode a file into an image.
Convenience wrapper that loads a file and encodes it. Convenience wrapper that loads a file and encodes it.
Args: Args:
filepath: Path to file to embed filepath: Path to file to embed
reference_photo: Shared reference photo bytes reference_photo: Shared reference photo bytes
@@ -181,12 +180,12 @@ def encode_file(
dct_output_format: 'png' or 'jpeg' dct_output_format: 'png' or 'jpeg'
dct_color_mode: 'grayscale' or 'color' dct_color_mode: 'grayscale' or 'color'
channel_key: Channel key parameter (see encode()) channel_key: Channel key parameter (see encode())
Returns: Returns:
EncodeResult EncodeResult
""" """
payload = FilePayload.from_file(str(filepath), filename_override) payload = FilePayload.from_file(str(filepath), filename_override)
return encode( return encode(
message=payload, message=payload,
reference_photo=reference_photo, reference_photo=reference_photo,
@@ -210,18 +209,18 @@ def encode_bytes(
carrier_image: bytes, carrier_image: bytes,
passphrase: str, passphrase: str,
pin: str = "", pin: str = "",
rsa_key_data: Optional[bytes] = None, rsa_key_data: bytes | None = None,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
output_format: Optional[str] = None, output_format: str | None = None,
mime_type: Optional[str] = None, mime_type: str | None = None,
embed_mode: str = EMBED_MODE_LSB, embed_mode: str = EMBED_MODE_LSB,
dct_output_format: str = "png", dct_output_format: str = "png",
dct_color_mode: str = "grayscale", dct_color_mode: str = "grayscale",
channel_key: Optional[Union[str, bool]] = None, channel_key: str | bool | None = None,
) -> EncodeResult: ) -> EncodeResult:
""" """
Encode raw bytes with metadata into an image. Encode raw bytes with metadata into an image.
Args: Args:
data: Raw bytes to embed data: Raw bytes to embed
filename: Filename to associate with data filename: Filename to associate with data
@@ -237,12 +236,12 @@ def encode_bytes(
dct_output_format: 'png' or 'jpeg' dct_output_format: 'png' or 'jpeg'
dct_color_mode: 'grayscale' or 'color' dct_color_mode: 'grayscale' or 'color'
channel_key: Channel key parameter (see encode()) channel_key: Channel key parameter (see encode())
Returns: Returns:
EncodeResult EncodeResult
""" """
payload = FilePayload(data=data, filename=filename, mime_type=mime_type) payload = FilePayload(data=data, filename=filename, mime_type=mime_type)
return encode( return encode(
message=payload, message=payload,
reference_photo=reference_photo, reference_photo=reference_photo,

View File

@@ -89,7 +89,7 @@ class SteganographyError(StegasooError):
class CapacityError(SteganographyError): class CapacityError(SteganographyError):
"""Carrier image too small for message.""" """Carrier image too small for message."""
def __init__(self, needed: int, available: int): def __init__(self, needed: int, available: int):
self.needed = needed self.needed = needed
self.available = available self.available = available
@@ -129,7 +129,7 @@ class FileNotFoundError(FileError):
class FileTooLargeError(FileError): class FileTooLargeError(FileError):
"""File exceeds size limit.""" """File exceeds size limit."""
def __init__(self, size: int, limit: int, filename: str = "File"): def __init__(self, size: int, limit: int, filename: str = "File"):
self.size = size self.size = size
self.limit = limit self.limit = limit
@@ -141,7 +141,7 @@ class FileTooLargeError(FileError):
class UnsupportedFileTypeError(FileError): class UnsupportedFileTypeError(FileError):
"""File type not supported.""" """File type not supported."""
def __init__(self, extension: str, allowed: set[str]): def __init__(self, extension: str, allowed: set[str]):
self.extension = extension self.extension = extension
self.allowed = allowed self.allowed = allowed

View File

@@ -4,28 +4,30 @@ Stegasoo Generate Module (v3.2.0)
Public API for generating credentials (PINs, passphrases, RSA keys). Public API for generating credentials (PINs, passphrases, RSA keys).
""" """
from typing import Optional
from .keygen import (
generate_pin as _generate_pin,
generate_phrase,
generate_rsa_key as _generate_rsa_key,
export_rsa_key_pem,
load_rsa_key,
)
from .models import Credentials
from .constants import ( from .constants import (
DEFAULT_PIN_LENGTH,
DEFAULT_PASSPHRASE_WORDS, DEFAULT_PASSPHRASE_WORDS,
DEFAULT_PIN_LENGTH,
DEFAULT_RSA_BITS, DEFAULT_RSA_BITS,
) )
from .debug import debug from .debug import debug
from .keygen import (
export_rsa_key_pem,
generate_phrase,
load_rsa_key,
)
from .keygen import (
generate_pin as _generate_pin,
)
from .keygen import (
generate_rsa_key as _generate_rsa_key,
)
from .models import Credentials
# Re-export from keygen for convenience # Re-export from keygen for convenience
__all__ = [ __all__ = [
'generate_pin', 'generate_pin',
'generate_passphrase', 'generate_passphrase',
'generate_rsa_key', 'generate_rsa_key',
'generate_credentials', 'generate_credentials',
'export_rsa_key_pem', 'export_rsa_key_pem',
@@ -36,15 +38,15 @@ __all__ = [
def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str: def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
""" """
Generate a random PIN. Generate a random PIN.
PINs never start with zero for usability. PINs never start with zero for usability.
Args: Args:
length: PIN length (6-9 digits, default 6) length: PIN length (6-9 digits, default 6)
Returns: Returns:
PIN string PIN string
Example: Example:
>>> pin = generate_pin() >>> pin = generate_pin()
>>> len(pin) >>> len(pin)
@@ -58,16 +60,16 @@ def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
def generate_passphrase(words: int = DEFAULT_PASSPHRASE_WORDS) -> str: def generate_passphrase(words: int = DEFAULT_PASSPHRASE_WORDS) -> str:
""" """
Generate a random passphrase from BIP-39 wordlist. Generate a random passphrase from BIP-39 wordlist.
In v3.2.0, this generates a single passphrase (not daily phrases). In v3.2.0, this generates a single passphrase (not daily phrases).
Default is 4 words for good security (increased from 3 in v3.1.0). Default is 4 words for good security (increased from 3 in v3.1.0).
Args: Args:
words: Number of words (3-12, default 4) words: Number of words (3-12, default 4)
Returns: Returns:
Space-separated passphrase Space-separated passphrase
Example: Example:
>>> passphrase = generate_passphrase(4) >>> passphrase = generate_passphrase(4)
>>> len(passphrase.split()) >>> len(passphrase.split())
@@ -78,18 +80,18 @@ def generate_passphrase(words: int = DEFAULT_PASSPHRASE_WORDS) -> str:
def generate_rsa_key( def generate_rsa_key(
bits: int = DEFAULT_RSA_BITS, bits: int = DEFAULT_RSA_BITS,
password: Optional[str] = None password: str | None = None
) -> str: ) -> str:
""" """
Generate an RSA private key in PEM format. Generate an RSA private key in PEM format.
Args: Args:
bits: Key size (2048, 3072, or 4096, default 2048) bits: Key size (2048, 3072, or 4096, default 2048)
password: Optional password to encrypt the key password: Optional password to encrypt the key
Returns: Returns:
PEM-encoded key string PEM-encoded key string
Example: Example:
>>> key_pem = generate_rsa_key(2048) >>> key_pem = generate_rsa_key(2048)
>>> '-----BEGIN PRIVATE KEY-----' in key_pem >>> '-----BEGIN PRIVATE KEY-----' in key_pem
@@ -106,14 +108,14 @@ def generate_credentials(
pin_length: int = DEFAULT_PIN_LENGTH, pin_length: int = DEFAULT_PIN_LENGTH,
rsa_bits: int = DEFAULT_RSA_BITS, rsa_bits: int = DEFAULT_RSA_BITS,
passphrase_words: int = DEFAULT_PASSPHRASE_WORDS, passphrase_words: int = DEFAULT_PASSPHRASE_WORDS,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
) -> Credentials: ) -> Credentials:
""" """
Generate a complete set of credentials. Generate a complete set of credentials.
In v3.2.0, this generates a single passphrase (not daily phrases). In v3.2.0, this generates a single passphrase (not daily phrases).
At least one of use_pin or use_rsa must be True. At least one of use_pin or use_rsa must be True.
Args: Args:
use_pin: Whether to generate a PIN use_pin: Whether to generate a PIN
use_rsa: Whether to generate an RSA key use_rsa: Whether to generate an RSA key
@@ -121,13 +123,13 @@ def generate_credentials(
rsa_bits: RSA key size (default 2048) rsa_bits: RSA key size (default 2048)
passphrase_words: Number of words in passphrase (default 4) passphrase_words: Number of words in passphrase (default 4)
rsa_password: Optional password for RSA key rsa_password: Optional password for RSA key
Returns: Returns:
Credentials object with passphrase, PIN, and/or RSA key Credentials object with passphrase, PIN, and/or RSA key
Raises: Raises:
ValueError: If neither PIN nor RSA is selected ValueError: If neither PIN nor RSA is selected
Example: Example:
>>> creds = generate_credentials(use_pin=True, use_rsa=False) >>> creds = generate_credentials(use_pin=True, use_rsa=False)
>>> len(creds.passphrase.split()) >>> len(creds.passphrase.split())
@@ -137,23 +139,23 @@ def generate_credentials(
""" """
if not use_pin and not use_rsa: if not use_pin and not use_rsa:
raise ValueError("Must select at least one security factor (PIN or RSA key)") raise ValueError("Must select at least one security factor (PIN or RSA key)")
debug.print(f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, " debug.print(f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, "
f"passphrase_words={passphrase_words}") f"passphrase_words={passphrase_words}")
# Generate passphrase (single, not daily) # Generate passphrase (single, not daily)
passphrase = generate_phrase(passphrase_words) passphrase = generate_phrase(passphrase_words)
# Generate PIN if requested # Generate PIN if requested
pin = _generate_pin(pin_length) if use_pin else None pin = _generate_pin(pin_length) if use_pin else None
# Generate RSA key if requested # Generate RSA key if requested
rsa_key_pem = None rsa_key_pem = None
if use_rsa: if use_rsa:
rsa_key_obj = _generate_rsa_key(rsa_bits) rsa_key_obj = _generate_rsa_key(rsa_bits)
rsa_key_bytes = export_rsa_key_pem(rsa_key_obj, rsa_password) rsa_key_bytes = export_rsa_key_pem(rsa_key_obj, rsa_password)
rsa_key_pem = rsa_key_bytes.decode('utf-8') rsa_key_pem = rsa_key_bytes.decode('utf-8')
# Create Credentials object (v3.2.0 format) # Create Credentials object (v3.2.0 format)
creds = Credentials( creds = Credentials(
passphrase=passphrase, passphrase=passphrase,
@@ -162,6 +164,6 @@ def generate_credentials(
rsa_bits=rsa_bits if use_rsa else None, rsa_bits=rsa_bits if use_rsa else None,
words_per_passphrase=passphrase_words, words_per_passphrase=passphrase_words,
) )
debug.print(f"Credentials generated: {creds.total_entropy} bits total entropy") debug.print(f"Credentials generated: {creds.total_entropy} bits total entropy")
return creds return creds

View File

@@ -4,40 +4,40 @@ Stegasoo Image Utilities (v3.2.0)
Functions for analyzing images and comparing capacity. Functions for analyzing images and comparing capacity.
""" """
from typing import Optional
import io import io
from PIL import Image from PIL import Image
from .models import ImageInfo, CapacityComparison from .constants import EMBED_MODE_LSB
from .steganography import calculate_capacity, has_dct_support
from .constants import EMBED_MODE_LSB, EMBED_MODE_DCT
from .debug import debug from .debug import debug
from .models import CapacityComparison, ImageInfo
from .steganography import calculate_capacity, has_dct_support
def get_image_info(image_data: bytes) -> ImageInfo: def get_image_info(image_data: bytes) -> ImageInfo:
""" """
Get detailed information about an image. Get detailed information about an image.
Args: Args:
image_data: Image file bytes image_data: Image file bytes
Returns: Returns:
ImageInfo with dimensions, format, capacity estimates ImageInfo with dimensions, format, capacity estimates
Example: Example:
>>> info = get_image_info(carrier_bytes) >>> info = get_image_info(carrier_bytes)
>>> print(f"{info.width}x{info.height}, {info.lsb_capacity_kb} KB capacity") >>> print(f"{info.width}x{info.height}, {info.lsb_capacity_kb} KB capacity")
""" """
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
width, height = img.size width, height = img.size
pixels = width * height pixels = width * height
format_str = img.format or "Unknown" format_str = img.format or "Unknown"
mode = img.mode mode = img.mode
# Calculate LSB capacity # Calculate LSB capacity
lsb_capacity = calculate_capacity(image_data, bits_per_channel=1) lsb_capacity = calculate_capacity(image_data, bits_per_channel=1)
# Calculate DCT capacity if available # Calculate DCT capacity if available
dct_capacity = None dct_capacity = None
if has_dct_support(): if has_dct_support():
@@ -47,7 +47,7 @@ def get_image_info(image_data: bytes) -> ImageInfo:
dct_capacity = dct_info.usable_capacity_bytes dct_capacity = dct_info.usable_capacity_bytes
except Exception as e: except Exception as e:
debug.print(f"Could not calculate DCT capacity: {e}") debug.print(f"Could not calculate DCT capacity: {e}")
info = ImageInfo( info = ImageInfo(
width=width, width=width,
height=height, height=height,
@@ -60,27 +60,27 @@ def get_image_info(image_data: bytes) -> ImageInfo:
dct_capacity_bytes=dct_capacity, dct_capacity_bytes=dct_capacity,
dct_capacity_kb=dct_capacity / 1024 if dct_capacity else None, dct_capacity_kb=dct_capacity / 1024 if dct_capacity else None,
) )
debug.print(f"Image info: {width}x{height}, LSB={lsb_capacity} bytes, " debug.print(f"Image info: {width}x{height}, LSB={lsb_capacity} bytes, "
f"DCT={dct_capacity or 'N/A'} bytes") f"DCT={dct_capacity or 'N/A'} bytes")
return info return info
def compare_capacity( def compare_capacity(
carrier_image: bytes, carrier_image: bytes,
reference_photo: Optional[bytes] = None, reference_photo: bytes | None = None,
) -> CapacityComparison: ) -> CapacityComparison:
""" """
Compare embedding capacity between LSB and DCT modes. Compare embedding capacity between LSB and DCT modes.
Args: Args:
carrier_image: Carrier image bytes carrier_image: Carrier image bytes
reference_photo: Optional reference photo (not used in v3.2.0, kept for API compatibility) reference_photo: Optional reference photo (not used in v3.2.0, kept for API compatibility)
Returns: Returns:
CapacityComparison with capacity info for both modes CapacityComparison with capacity info for both modes
Example: Example:
>>> comparison = compare_capacity(carrier_bytes) >>> comparison = compare_capacity(carrier_bytes)
>>> print(f"LSB: {comparison.lsb_kb:.1f} KB") >>> print(f"LSB: {comparison.lsb_kb:.1f} KB")
@@ -88,16 +88,16 @@ def compare_capacity(
""" """
img = Image.open(io.BytesIO(carrier_image)) img = Image.open(io.BytesIO(carrier_image))
width, height = img.size width, height = img.size
# LSB capacity # LSB capacity
lsb_bytes = calculate_capacity(carrier_image, bits_per_channel=1) lsb_bytes = calculate_capacity(carrier_image, bits_per_channel=1)
lsb_kb = lsb_bytes / 1024 lsb_kb = lsb_bytes / 1024
# DCT capacity # DCT capacity
dct_available = has_dct_support() dct_available = has_dct_support()
dct_bytes = None dct_bytes = None
dct_kb = None dct_kb = None
if dct_available: if dct_available:
try: try:
from .dct_steganography import calculate_dct_capacity from .dct_steganography import calculate_dct_capacity
@@ -107,7 +107,7 @@ def compare_capacity(
except Exception as e: except Exception as e:
debug.print(f"DCT capacity calculation failed: {e}") debug.print(f"DCT capacity calculation failed: {e}")
dct_available = False dct_available = False
comparison = CapacityComparison( comparison = CapacityComparison(
image_width=width, image_width=width,
image_height=height, image_height=height,
@@ -121,9 +121,9 @@ def compare_capacity(
dct_output_formats=["PNG (grayscale)", "JPEG (grayscale)"] if dct_available else None, dct_output_formats=["PNG (grayscale)", "JPEG (grayscale)"] if dct_available else None,
dct_ratio_vs_lsb=(dct_bytes / lsb_bytes * 100) if dct_bytes else None, dct_ratio_vs_lsb=(dct_bytes / lsb_bytes * 100) if dct_bytes else None,
) )
debug.print(f"Capacity comparison: LSB={lsb_kb:.1f}KB, DCT={dct_kb or 'N/A'}KB") debug.print(f"Capacity comparison: LSB={lsb_kb:.1f}KB, DCT={dct_kb or 'N/A'}KB")
return comparison return comparison
@@ -134,27 +134,27 @@ def validate_carrier_capacity(
) -> dict: ) -> dict:
""" """
Check if a payload will fit in a carrier image. Check if a payload will fit in a carrier image.
Args: Args:
carrier_image: Carrier image bytes carrier_image: Carrier image bytes
payload_size: Size of payload in bytes payload_size: Size of payload in bytes
embed_mode: 'lsb' or 'dct' embed_mode: 'lsb' or 'dct'
Returns: Returns:
Dict with 'fits', 'capacity', 'usage_percent', 'headroom' Dict with 'fits', 'capacity', 'usage_percent', 'headroom'
""" """
from .steganography import calculate_capacity_by_mode from .steganography import calculate_capacity_by_mode
capacity_info = calculate_capacity_by_mode(carrier_image, embed_mode) capacity_info = calculate_capacity_by_mode(carrier_image, embed_mode)
capacity = capacity_info['capacity_bytes'] capacity = capacity_info['capacity_bytes']
# Add encryption overhead estimate # Add encryption overhead estimate
estimated_size = payload_size + 200 # Approximate overhead estimated_size = payload_size + 200 # Approximate overhead
fits = estimated_size <= capacity fits = estimated_size <= capacity
usage_percent = (estimated_size / capacity * 100) if capacity > 0 else 100.0 usage_percent = (estimated_size / capacity * 100) if capacity > 0 else 100.0
headroom = capacity - estimated_size headroom = capacity - estimated_size
return { return {
'fits': fits, 'fits': fits,
'capacity': capacity, 'capacity': capacity,

View File

@@ -10,53 +10,57 @@ Changes in v3.2.0:
""" """
import secrets import secrets
from typing import Optional, Dict, Union
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.backends import default_backend
from .constants import ( from .constants import (
DAY_NAMES, DAY_NAMES,
MIN_PIN_LENGTH, MAX_PIN_LENGTH, DEFAULT_PIN_LENGTH, DEFAULT_PASSPHRASE_WORDS,
MIN_PASSPHRASE_WORDS, MAX_PASSPHRASE_WORDS, DEFAULT_PASSPHRASE_WORDS, DEFAULT_PIN_LENGTH,
MIN_RSA_BITS, VALID_RSA_SIZES, DEFAULT_RSA_BITS, DEFAULT_RSA_BITS,
MAX_PASSPHRASE_WORDS,
MAX_PIN_LENGTH,
MIN_PASSPHRASE_WORDS,
MIN_PIN_LENGTH,
VALID_RSA_SIZES,
get_wordlist, get_wordlist,
) )
from .models import Credentials, KeyInfo
from .exceptions import KeyGenerationError, KeyPasswordError
from .debug import debug from .debug import debug
from .exceptions import KeyGenerationError, KeyPasswordError
from .models import Credentials, KeyInfo
def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str: def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
""" """
Generate a random PIN. Generate a random PIN.
PINs never start with zero for usability. PINs never start with zero for usability.
Args: Args:
length: PIN length (6-9 digits) length: PIN length (6-9 digits)
Returns: Returns:
PIN string PIN string
Example: Example:
>>> generate_pin(6) >>> generate_pin(6)
"812345" "812345"
""" """
debug.validate(MIN_PIN_LENGTH <= length <= MAX_PIN_LENGTH, debug.validate(MIN_PIN_LENGTH <= length <= MAX_PIN_LENGTH,
f"PIN length must be between {MIN_PIN_LENGTH} and {MAX_PIN_LENGTH}") f"PIN length must be between {MIN_PIN_LENGTH} and {MAX_PIN_LENGTH}")
length = max(MIN_PIN_LENGTH, min(MAX_PIN_LENGTH, length)) length = max(MIN_PIN_LENGTH, min(MAX_PIN_LENGTH, length))
# First digit: 1-9 (no leading zero) # First digit: 1-9 (no leading zero)
first_digit = str(secrets.randbelow(9) + 1) first_digit = str(secrets.randbelow(9) + 1)
# Remaining digits: 0-9 # Remaining digits: 0-9
rest = ''.join(str(secrets.randbelow(10)) for _ in range(length - 1)) rest = ''.join(str(secrets.randbelow(10)) for _ in range(length - 1))
pin = first_digit + rest pin = first_digit + rest
debug.print(f"Generated PIN: {pin}") debug.print(f"Generated PIN: {pin}")
return pin return pin
@@ -65,23 +69,23 @@ def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
def generate_phrase(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> str: def generate_phrase(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> str:
""" """
Generate a random passphrase from BIP-39 wordlist. Generate a random passphrase from BIP-39 wordlist.
Args: Args:
words_per_phrase: Number of words (3-12) words_per_phrase: Number of words (3-12)
Returns: Returns:
Space-separated phrase Space-separated phrase
Example: Example:
>>> generate_phrase(4) >>> generate_phrase(4)
"apple forest thunder mountain" "apple forest thunder mountain"
""" """
debug.validate(MIN_PASSPHRASE_WORDS <= words_per_phrase <= MAX_PASSPHRASE_WORDS, debug.validate(MIN_PASSPHRASE_WORDS <= words_per_phrase <= MAX_PASSPHRASE_WORDS,
f"Words per phrase must be between {MIN_PASSPHRASE_WORDS} and {MAX_PASSPHRASE_WORDS}") f"Words per phrase must be between {MIN_PASSPHRASE_WORDS} and {MAX_PASSPHRASE_WORDS}")
words_per_phrase = max(MIN_PASSPHRASE_WORDS, min(MAX_PASSPHRASE_WORDS, words_per_phrase)) words_per_phrase = max(MIN_PASSPHRASE_WORDS, min(MAX_PASSPHRASE_WORDS, words_per_phrase))
wordlist = get_wordlist() wordlist = get_wordlist()
words = [secrets.choice(wordlist) for _ in range(words_per_phrase)] words = [secrets.choice(wordlist) for _ in range(words_per_phrase)]
phrase = ' '.join(words) phrase = ' '.join(words)
debug.print(f"Generated phrase: {phrase}") debug.print(f"Generated phrase: {phrase}")
@@ -92,19 +96,19 @@ def generate_phrase(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> str:
generate_passphrase = generate_phrase generate_passphrase = generate_phrase
def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> Dict[str, str]: def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> dict[str, str]:
""" """
Generate phrases for all days of the week. Generate phrases for all days of the week.
DEPRECATED in v3.2.0: Use generate_phrase() for single passphrase. DEPRECATED in v3.2.0: Use generate_phrase() for single passphrase.
Kept for legacy compatibility and organizational use cases. Kept for legacy compatibility and organizational use cases.
Args: Args:
words_per_phrase: Number of words per phrase (3-12) words_per_phrase: Number of words per phrase (3-12)
Returns: Returns:
Dict mapping day names to phrases Dict mapping day names to phrases
Example: Example:
>>> generate_day_phrases(3) >>> generate_day_phrases(3)
{'Monday': 'apple forest thunder', 'Tuesday': 'banana river lightning', ...} {'Monday': 'apple forest thunder', 'Tuesday': 'banana river lightning', ...}
@@ -116,7 +120,7 @@ def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> Di
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2
) )
phrases = {day: generate_phrase(words_per_phrase) for day in DAY_NAMES} phrases = {day: generate_phrase(words_per_phrase) for day in DAY_NAMES}
debug.print(f"Generated phrases for {len(phrases)} days") debug.print(f"Generated phrases for {len(phrases)} days")
return phrases return phrases
@@ -125,16 +129,16 @@ def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> Di
def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey: def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
""" """
Generate an RSA private key. Generate an RSA private key.
Args: Args:
bits: Key size (2048, 3072, or 4096) bits: Key size (2048, 3072, or 4096)
Returns: Returns:
RSA private key object RSA private key object
Raises: Raises:
KeyGenerationError: If generation fails KeyGenerationError: If generation fails
Example: Example:
>>> key = generate_rsa_key(2048) >>> key = generate_rsa_key(2048)
>>> key.key_size >>> key.key_size
@@ -142,10 +146,10 @@ def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
""" """
debug.validate(bits in VALID_RSA_SIZES, debug.validate(bits in VALID_RSA_SIZES,
f"RSA key size must be one of {VALID_RSA_SIZES}") f"RSA key size must be one of {VALID_RSA_SIZES}")
if bits not in VALID_RSA_SIZES: if bits not in VALID_RSA_SIZES:
bits = DEFAULT_RSA_BITS bits = DEFAULT_RSA_BITS
debug.print(f"Generating {bits}-bit RSA key...") debug.print(f"Generating {bits}-bit RSA key...")
try: try:
key = rsa.generate_private_key( key = rsa.generate_private_key(
@@ -162,18 +166,18 @@ def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
def export_rsa_key_pem( def export_rsa_key_pem(
private_key: rsa.RSAPrivateKey, private_key: rsa.RSAPrivateKey,
password: Optional[str] = None password: str | None = None
) -> bytes: ) -> bytes:
""" """
Export RSA key to PEM format. Export RSA key to PEM format.
Args: Args:
private_key: RSA private key object private_key: RSA private key object
password: Optional password for encryption password: Optional password for encryption
Returns: Returns:
PEM-encoded key bytes PEM-encoded key bytes
Example: Example:
>>> key = generate_rsa_key() >>> key = generate_rsa_key()
>>> pem = export_rsa_key_pem(key) >>> pem = export_rsa_key_pem(key)
@@ -181,19 +185,16 @@ def export_rsa_key_pem(
b'-----BEGIN PRIVATE KEY-----\\nMIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYw' b'-----BEGIN PRIVATE KEY-----\\nMIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYw'
""" """
debug.validate(private_key is not None, "Private key cannot be None") debug.validate(private_key is not None, "Private key cannot be None")
encryption_algorithm: Union[ encryption_algorithm: serialization.BestAvailableEncryption | serialization.NoEncryption
serialization.BestAvailableEncryption,
serialization.NoEncryption
]
if password: if password:
encryption_algorithm = serialization.BestAvailableEncryption(password.encode()) encryption_algorithm = serialization.BestAvailableEncryption(password.encode())
debug.print("Exporting RSA key with encryption") debug.print("Exporting RSA key with encryption")
else: else:
encryption_algorithm = serialization.NoEncryption() encryption_algorithm = serialization.NoEncryption()
debug.print("Exporting RSA key without encryption") debug.print("Exporting RSA key without encryption")
return private_key.private_bytes( return private_key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8, format=serialization.PrivateFormat.PKCS8,
@@ -203,39 +204,39 @@ def export_rsa_key_pem(
def load_rsa_key( def load_rsa_key(
key_data: bytes, key_data: bytes,
password: Optional[str] = None password: str | None = None
) -> rsa.RSAPrivateKey: ) -> rsa.RSAPrivateKey:
""" """
Load RSA private key from PEM data. Load RSA private key from PEM data.
Args: Args:
key_data: PEM-encoded key bytes key_data: PEM-encoded key bytes
password: Password if key is encrypted password: Password if key is encrypted
Returns: Returns:
RSA private key object RSA private key object
Raises: Raises:
KeyPasswordError: If password is wrong or missing KeyPasswordError: If password is wrong or missing
KeyGenerationError: If key is invalid KeyGenerationError: If key is invalid
Example: Example:
>>> key = load_rsa_key(pem_data, "my_password") >>> key = load_rsa_key(pem_data, "my_password")
""" """
debug.validate(key_data is not None and len(key_data) > 0, debug.validate(key_data is not None and len(key_data) > 0,
"Key data cannot be empty") "Key data cannot be empty")
try: try:
pwd_bytes = password.encode() if password else None pwd_bytes = password.encode() if password else None
debug.print(f"Loading RSA key (encrypted: {bool(password)})") debug.print(f"Loading RSA key (encrypted: {bool(password)})")
key: PrivateKeyTypes = load_pem_private_key( key: PrivateKeyTypes = load_pem_private_key(
key_data, password=pwd_bytes, backend=default_backend() key_data, password=pwd_bytes, backend=default_backend()
) )
# Verify it's an RSA key # Verify it's an RSA key
if not isinstance(key, rsa.RSAPrivateKey): if not isinstance(key, rsa.RSAPrivateKey):
raise KeyGenerationError(f"Expected RSA key, got {type(key).__name__}") raise KeyGenerationError(f"Expected RSA key, got {type(key).__name__}")
debug.print(f"RSA key loaded: {key.key_size} bits") debug.print(f"RSA key loaded: {key.key_size} bits")
return key return key
except TypeError: except TypeError:
@@ -253,17 +254,17 @@ def load_rsa_key(
raise KeyGenerationError(f"Could not load RSA key: {e}") from e raise KeyGenerationError(f"Could not load RSA key: {e}") from e
def get_key_info(key_data: bytes, password: Optional[str] = None) -> KeyInfo: def get_key_info(key_data: bytes, password: str | None = None) -> KeyInfo:
""" """
Get information about an RSA key. Get information about an RSA key.
Args: Args:
key_data: PEM-encoded key bytes key_data: PEM-encoded key bytes
password: Password if key is encrypted password: Password if key is encrypted
Returns: Returns:
KeyInfo with key size and encryption status KeyInfo with key size and encryption status
Example: Example:
>>> info = get_key_info(pem_data) >>> info = get_key_info(pem_data)
>>> info.key_size >>> info.key_size
@@ -274,15 +275,15 @@ def get_key_info(key_data: bytes, password: Optional[str] = None) -> KeyInfo:
debug.print("Getting RSA key info") debug.print("Getting RSA key info")
# Check if encrypted # Check if encrypted
is_encrypted = b'ENCRYPTED' in key_data is_encrypted = b'ENCRYPTED' in key_data
private_key = load_rsa_key(key_data, password) private_key = load_rsa_key(key_data, password)
info = KeyInfo( info = KeyInfo(
key_size=private_key.key_size, key_size=private_key.key_size,
is_encrypted=is_encrypted, is_encrypted=is_encrypted,
pem_data=key_data pem_data=key_data
) )
debug.print(f"Key info: {info.key_size} bits, encrypted: {info.is_encrypted}") debug.print(f"Key info: {info.key_size} bits, encrypted: {info.is_encrypted}")
return info return info
@@ -293,14 +294,14 @@ def generate_credentials(
pin_length: int = DEFAULT_PIN_LENGTH, pin_length: int = DEFAULT_PIN_LENGTH,
rsa_bits: int = DEFAULT_RSA_BITS, rsa_bits: int = DEFAULT_RSA_BITS,
passphrase_words: int = DEFAULT_PASSPHRASE_WORDS, passphrase_words: int = DEFAULT_PASSPHRASE_WORDS,
rsa_password: Optional[str] = None, rsa_password: str | None = None,
) -> Credentials: ) -> Credentials:
""" """
Generate a complete set of credentials. Generate a complete set of credentials.
v3.2.0: Now generates a single passphrase instead of daily phrases. v3.2.0: Now generates a single passphrase instead of daily phrases.
At least one of use_pin or use_rsa must be True. At least one of use_pin or use_rsa must be True.
Args: Args:
use_pin: Whether to generate a PIN use_pin: Whether to generate a PIN
use_rsa: Whether to generate an RSA key use_rsa: Whether to generate an RSA key
@@ -308,13 +309,13 @@ def generate_credentials(
rsa_bits: RSA key size if generating (default 2048) rsa_bits: RSA key size if generating (default 2048)
passphrase_words: Words in passphrase (default 4) passphrase_words: Words in passphrase (default 4)
rsa_password: Optional password for RSA key encryption rsa_password: Optional password for RSA key encryption
Returns: Returns:
Credentials object with passphrase, PIN, and/or RSA key Credentials object with passphrase, PIN, and/or RSA key
Raises: Raises:
ValueError: If neither PIN nor RSA is selected ValueError: If neither PIN nor RSA is selected
Example: Example:
>>> creds = generate_credentials(use_pin=True, use_rsa=False) >>> creds = generate_credentials(use_pin=True, use_rsa=False)
>>> creds.passphrase >>> creds.passphrase
@@ -324,25 +325,25 @@ def generate_credentials(
""" """
debug.validate(use_pin or use_rsa, debug.validate(use_pin or use_rsa,
"Must select at least one security factor (PIN or RSA key)") "Must select at least one security factor (PIN or RSA key)")
if not use_pin and not use_rsa: if not use_pin and not use_rsa:
raise ValueError("Must select at least one security factor (PIN or RSA key)") raise ValueError("Must select at least one security factor (PIN or RSA key)")
debug.print(f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, " debug.print(f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, "
f"passphrase_words={passphrase_words}") f"passphrase_words={passphrase_words}")
# Generate single passphrase (v3.2.0 - no daily rotation) # Generate single passphrase (v3.2.0 - no daily rotation)
passphrase = generate_phrase(passphrase_words) passphrase = generate_phrase(passphrase_words)
# Generate PIN if requested # Generate PIN if requested
pin = generate_pin(pin_length) if use_pin else None pin = generate_pin(pin_length) if use_pin else None
# Generate RSA key if requested # Generate RSA key if requested
rsa_key_pem = None rsa_key_pem = None
if use_rsa: if use_rsa:
rsa_key_obj = generate_rsa_key(rsa_bits) rsa_key_obj = generate_rsa_key(rsa_bits)
rsa_key_pem = export_rsa_key_pem(rsa_key_obj, rsa_password).decode('utf-8') rsa_key_pem = export_rsa_key_pem(rsa_key_obj, rsa_password).decode('utf-8')
# Create Credentials object (v3.2.0 format with single passphrase) # Create Credentials object (v3.2.0 format with single passphrase)
creds = Credentials( creds = Credentials(
passphrase=passphrase, passphrase=passphrase,
@@ -351,7 +352,7 @@ def generate_credentials(
rsa_bits=rsa_bits if use_rsa else None, rsa_bits=rsa_bits if use_rsa else None,
words_per_passphrase=passphrase_words, words_per_passphrase=passphrase_words,
) )
debug.print(f"Credentials generated: {creds.total_entropy} bits total entropy") debug.print(f"Credentials generated: {creds.total_entropy} bits total entropy")
return creds return creds
@@ -369,19 +370,19 @@ def generate_credentials_legacy(
) -> dict: ) -> dict:
""" """
Generate credentials in legacy format (v3.1.0 style with daily phrases). Generate credentials in legacy format (v3.1.0 style with daily phrases).
DEPRECATED: Use generate_credentials() for v3.2.0 format. DEPRECATED: Use generate_credentials() for v3.2.0 format.
This function exists only for migration tools that need to work with This function exists only for migration tools that need to work with
old-format credentials. old-format credentials.
Args: Args:
use_pin: Whether to generate a PIN use_pin: Whether to generate a PIN
use_rsa: Whether to generate an RSA key use_rsa: Whether to generate an RSA key
pin_length: PIN length if generating pin_length: PIN length if generating
rsa_bits: RSA key size if generating rsa_bits: RSA key size if generating
words_per_phrase: Words per daily phrase words_per_phrase: Words per daily phrase
Returns: Returns:
Dict with 'phrases' (dict), 'pin', 'rsa_key_pem', etc. Dict with 'phrases' (dict), 'pin', 'rsa_key_pem', etc.
""" """
@@ -392,20 +393,20 @@ def generate_credentials_legacy(
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2
) )
if not use_pin and not use_rsa: if not use_pin and not use_rsa:
raise ValueError("Must select at least one security factor (PIN or RSA key)") raise ValueError("Must select at least one security factor (PIN or RSA key)")
# Generate daily phrases (old format) # Generate daily phrases (old format)
phrases = {day: generate_phrase(words_per_phrase) for day in DAY_NAMES} phrases = {day: generate_phrase(words_per_phrase) for day in DAY_NAMES}
pin = generate_pin(pin_length) if use_pin else None pin = generate_pin(pin_length) if use_pin else None
rsa_key_pem = None rsa_key_pem = None
if use_rsa: if use_rsa:
rsa_key_obj = generate_rsa_key(rsa_bits) rsa_key_obj = generate_rsa_key(rsa_bits)
rsa_key_pem = export_rsa_key_pem(rsa_key_obj).decode('utf-8') rsa_key_pem = export_rsa_key_pem(rsa_key_obj).decode('utf-8')
return { return {
'phrases': phrases, 'phrases': phrases,
'pin': pin, 'pin': pin,

View File

@@ -12,50 +12,48 @@ Changes in v3.2.0:
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import date
from typing import Optional, Union, List
@dataclass @dataclass
class Credentials: class Credentials:
""" """
Generated credentials for encoding/decoding. Generated credentials for encoding/decoding.
v3.2.0: Simplified to use single passphrase instead of daily rotation. v3.2.0: Simplified to use single passphrase instead of daily rotation.
""" """
passphrase: str # Single passphrase (no daily rotation) passphrase: str # Single passphrase (no daily rotation)
pin: Optional[str] = None pin: str | None = None
rsa_key_pem: Optional[str] = None rsa_key_pem: str | None = None
rsa_bits: Optional[int] = None rsa_bits: int | None = None
words_per_passphrase: int = 4 # Increased from 3 in v3.1.0 words_per_passphrase: int = 4 # Increased from 3 in v3.1.0
# Optional: backup passphrases for multi-factor or rotation # Optional: backup passphrases for multi-factor or rotation
backup_passphrases: Optional[list[str]] = None backup_passphrases: list[str] | None = None
@property @property
def passphrase_entropy(self) -> int: def passphrase_entropy(self) -> int:
"""Entropy in bits from passphrase (~11 bits per BIP-39 word).""" """Entropy in bits from passphrase (~11 bits per BIP-39 word)."""
return self.words_per_passphrase * 11 return self.words_per_passphrase * 11
@property @property
def pin_entropy(self) -> int: def pin_entropy(self) -> int:
"""Entropy in bits from PIN (~3.32 bits per digit).""" """Entropy in bits from PIN (~3.32 bits per digit)."""
if self.pin: if self.pin:
return int(len(self.pin) * 3.32) return int(len(self.pin) * 3.32)
return 0 return 0
@property @property
def rsa_entropy(self) -> int: def rsa_entropy(self) -> int:
"""Effective entropy from RSA key.""" """Effective entropy from RSA key."""
if self.rsa_key_pem and self.rsa_bits: if self.rsa_key_pem and self.rsa_bits:
return min(self.rsa_bits // 16, 128) return min(self.rsa_bits // 16, 128)
return 0 return 0
@property @property
def total_entropy(self) -> int: def total_entropy(self) -> int:
"""Total entropy in bits (excluding reference photo).""" """Total entropy in bits (excluding reference photo)."""
return self.passphrase_entropy + self.pin_entropy + self.rsa_entropy return self.passphrase_entropy + self.pin_entropy + self.rsa_entropy
# Legacy property for compatibility # Legacy property for compatibility
@property @property
def phrase_entropy(self) -> int: def phrase_entropy(self) -> int:
@@ -68,23 +66,23 @@ class FilePayload:
"""Represents a file to be embedded.""" """Represents a file to be embedded."""
data: bytes data: bytes
filename: str filename: str
mime_type: Optional[str] = None mime_type: str | None = None
@property @property
def size(self) -> int: def size(self) -> int:
return len(self.data) return len(self.data)
@classmethod @classmethod
def from_file(cls, filepath: str, filename: Optional[str] = None) -> 'FilePayload': def from_file(cls, filepath: str, filename: str | None = None) -> 'FilePayload':
"""Create FilePayload from a file path.""" """Create FilePayload from a file path."""
from pathlib import Path
import mimetypes import mimetypes
from pathlib import Path
path = Path(filepath) path = Path(filepath)
data = path.read_bytes() data = path.read_bytes()
name = filename or path.name name = filename or path.name
mime, _ = mimetypes.guess_type(name) mime, _ = mimetypes.guess_type(name)
return cls(data=data, filename=name, mime_type=mime) return cls(data=data, filename=name, mime_type=mime)
@@ -92,23 +90,23 @@ class FilePayload:
class EncodeInput: class EncodeInput:
""" """
Input parameters for encoding a message. Input parameters for encoding a message.
v3.2.0: Removed date_str (date no longer used in crypto). v3.2.0: Removed date_str (date no longer used in crypto).
""" """
message: Union[str, bytes, FilePayload] # Text, raw bytes, or file message: str | bytes | FilePayload # Text, raw bytes, or file
reference_photo: bytes reference_photo: bytes
carrier_image: bytes carrier_image: bytes
passphrase: str # Renamed from day_phrase passphrase: str # Renamed from day_phrase
pin: str = "" pin: str = ""
rsa_key_data: Optional[bytes] = None rsa_key_data: bytes | None = None
rsa_password: Optional[str] = None rsa_password: str | None = None
@dataclass @dataclass
class EncodeResult: class EncodeResult:
""" """
Result of encoding operation. Result of encoding operation.
v3.2.0: date_used is now optional/cosmetic (not used in crypto). v3.2.0: date_used is now optional/cosmetic (not used in crypto).
""" """
stego_image: bytes stego_image: bytes
@@ -116,8 +114,8 @@ class EncodeResult:
pixels_modified: int pixels_modified: int
total_pixels: int total_pixels: int
capacity_used: float # 0.0 - 1.0 capacity_used: float # 0.0 - 1.0
date_used: Optional[str] = None # Cosmetic only (for filename organization) date_used: str | None = None # Cosmetic only (for filename organization)
@property @property
def capacity_percent(self) -> float: def capacity_percent(self) -> float:
"""Capacity used as percentage.""" """Capacity used as percentage."""
@@ -128,54 +126,54 @@ class EncodeResult:
class DecodeInput: class DecodeInput:
""" """
Input parameters for decoding a message. Input parameters for decoding a message.
v3.2.0: Renamed day_phrase → passphrase, no date needed. v3.2.0: Renamed day_phrase → passphrase, no date needed.
""" """
stego_image: bytes stego_image: bytes
reference_photo: bytes reference_photo: bytes
passphrase: str # Renamed from day_phrase passphrase: str # Renamed from day_phrase
pin: str = "" pin: str = ""
rsa_key_data: Optional[bytes] = None rsa_key_data: bytes | None = None
rsa_password: Optional[str] = None rsa_password: str | None = None
@dataclass @dataclass
class DecodeResult: class DecodeResult:
""" """
Result of decoding operation. Result of decoding operation.
v3.2.0: date_encoded is always None (date removed from crypto). v3.2.0: date_encoded is always None (date removed from crypto).
""" """
payload_type: str # 'text' or 'file' payload_type: str # 'text' or 'file'
message: Optional[str] = None # For text payloads message: str | None = None # For text payloads
file_data: Optional[bytes] = None # For file payloads file_data: bytes | None = None # For file payloads
filename: Optional[str] = None # Original filename for file payloads filename: str | None = None # Original filename for file payloads
mime_type: Optional[str] = None # MIME type hint mime_type: str | None = None # MIME type hint
date_encoded: Optional[str] = None # Always None in v3.2.0 (kept for compatibility) date_encoded: str | None = None # Always None in v3.2.0 (kept for compatibility)
@property @property
def is_file(self) -> bool: def is_file(self) -> bool:
return self.payload_type == 'file' return self.payload_type == 'file'
@property @property
def is_text(self) -> bool: def is_text(self) -> bool:
return self.payload_type == 'text' return self.payload_type == 'text'
def get_content(self) -> Union[str, bytes]: def get_content(self) -> str | bytes:
"""Get the decoded content (text or bytes).""" """Get the decoded content (text or bytes)."""
if self.is_text: if self.is_text:
return self.message or "" return self.message or ""
return self.file_data or b"" return self.file_data or b""
@dataclass @dataclass
class EmbedStats: class EmbedStats:
"""Statistics from image embedding.""" """Statistics from image embedding."""
pixels_modified: int pixels_modified: int
total_pixels: int total_pixels: int
capacity_used: float capacity_used: float
bytes_embedded: int bytes_embedded: int
@property @property
def modification_percent(self) -> float: def modification_percent(self) -> float:
"""Percentage of pixels modified.""" """Percentage of pixels modified."""
@@ -196,16 +194,16 @@ class ValidationResult:
is_valid: bool is_valid: bool
error_message: str = "" error_message: str = ""
details: dict = field(default_factory=dict) details: dict = field(default_factory=dict)
warning: Optional[str] = None # v3.2.0: Added for passphrase length warnings warning: str | None = None # v3.2.0: Added for passphrase length warnings
@classmethod @classmethod
def ok(cls, warning: Optional[str] = None, **details) -> 'ValidationResult': def ok(cls, warning: str | None = None, **details) -> 'ValidationResult':
"""Create a successful validation result.""" """Create a successful validation result."""
result = cls(is_valid=True, details=details) result = cls(is_valid=True, details=details)
if warning: if warning:
result.warning = warning result.warning = warning
return result return result
@classmethod @classmethod
def error(cls, message: str, **details) -> 'ValidationResult': def error(cls, message: str, **details) -> 'ValidationResult':
"""Create a failed validation result.""" """Create a failed validation result."""
@@ -227,8 +225,8 @@ class ImageInfo:
file_size: int file_size: int
lsb_capacity_bytes: int lsb_capacity_bytes: int
lsb_capacity_kb: float lsb_capacity_kb: float
dct_capacity_bytes: Optional[int] = None dct_capacity_bytes: int | None = None
dct_capacity_kb: Optional[float] = None dct_capacity_kb: float | None = None
@dataclass @dataclass
@@ -241,24 +239,24 @@ class CapacityComparison:
lsb_kb: float lsb_kb: float
lsb_output_format: str lsb_output_format: str
dct_available: bool dct_available: bool
dct_bytes: Optional[int] = None dct_bytes: int | None = None
dct_kb: Optional[float] = None dct_kb: float | None = None
dct_output_formats: Optional[List[str]] = None dct_output_formats: list[str] | None = None
dct_ratio_vs_lsb: Optional[float] = None dct_ratio_vs_lsb: float | None = None
@dataclass @dataclass
class GenerateResult: class GenerateResult:
"""Result of credential generation.""" """Result of credential generation."""
passphrase: str passphrase: str
pin: Optional[str] = None pin: str | None = None
rsa_key_pem: Optional[str] = None rsa_key_pem: str | None = None
passphrase_words: int = 4 passphrase_words: int = 4
passphrase_entropy: int = 0 passphrase_entropy: int = 0
pin_entropy: int = 0 pin_entropy: int = 0
rsa_entropy: int = 0 rsa_entropy: int = 0
total_entropy: int = 0 total_entropy: int = 0
def __str__(self) -> str: def __str__(self) -> str:
lines = [ lines = [
"Generated Credentials:", "Generated Credentials:",

View File

@@ -10,10 +10,9 @@ IMPROVEMENTS IN THIS VERSION:
- Improved error messages - Improved error messages
""" """
import base64
import io import io
import zlib import zlib
import base64
from typing import Optional, Tuple
from PIL import Image from PIL import Image
@@ -27,20 +26,19 @@ except ImportError:
# QR code reading # QR code reading
try: try:
from pyzbar.pyzbar import decode as pyzbar_decode
from pyzbar.pyzbar import ZBarSymbol from pyzbar.pyzbar import ZBarSymbol
from pyzbar.pyzbar import decode as pyzbar_decode
HAS_QRCODE_READ = True HAS_QRCODE_READ = True
except ImportError: except ImportError:
HAS_QRCODE_READ = False HAS_QRCODE_READ = False
from .constants import ( from .constants import (
QR_MAX_BINARY,
QR_CROP_PADDING_PERCENT,
QR_CROP_MIN_PADDING_PX, QR_CROP_MIN_PADDING_PX,
QR_CROP_PADDING_PERCENT,
QR_MAX_BINARY,
) )
# Constants # Constants
COMPRESSION_PREFIX = "STEGASOO-Z:" COMPRESSION_PREFIX = "STEGASOO-Z:"
@@ -48,10 +46,10 @@ COMPRESSION_PREFIX = "STEGASOO-Z:"
def compress_data(data: str) -> str: def compress_data(data: str) -> str:
""" """
Compress string data for QR code storage. Compress string data for QR code storage.
Args: Args:
data: String to compress data: String to compress
Returns: Returns:
Compressed string with STEGASOO-Z: prefix Compressed string with STEGASOO-Z: prefix
""" """
@@ -63,19 +61,19 @@ def compress_data(data: str) -> str:
def decompress_data(data: str) -> str: def decompress_data(data: str) -> str:
""" """
Decompress data from QR code. Decompress data from QR code.
Args: Args:
data: Compressed string with STEGASOO-Z: prefix data: Compressed string with STEGASOO-Z: prefix
Returns: Returns:
Original uncompressed string Original uncompressed string
Raises: Raises:
ValueError: If data is not valid compressed format ValueError: If data is not valid compressed format
""" """
if not data.startswith(COMPRESSION_PREFIX): if not data.startswith(COMPRESSION_PREFIX):
raise ValueError("Data is not in compressed format") raise ValueError("Data is not in compressed format")
encoded = data[len(COMPRESSION_PREFIX):] encoded = data[len(COMPRESSION_PREFIX):]
compressed = base64.b64decode(encoded) compressed = base64.b64decode(encoded)
return zlib.decompress(compressed).decode('utf-8') return zlib.decompress(compressed).decode('utf-8')
@@ -84,7 +82,7 @@ def decompress_data(data: str) -> str:
def normalize_pem(pem_data: str) -> str: def normalize_pem(pem_data: str) -> str:
""" """
Normalize PEM data to ensure proper formatting for cryptography library. Normalize PEM data to ensure proper formatting for cryptography library.
The cryptography library is very particular about PEM formatting. The cryptography library is very particular about PEM formatting.
This function handles all common issues from QR code extraction: This function handles all common issues from QR code extraction:
- Inconsistent line endings (CRLF, LF, CR) - Inconsistent line endings (CRLF, LF, CR)
@@ -93,24 +91,24 @@ def normalize_pem(pem_data: str) -> str:
- Non-ASCII characters - Non-ASCII characters
- Incorrect base64 padding - Incorrect base64 padding
- Malformed headers/footers - Malformed headers/footers
Args: Args:
pem_data: Raw PEM string from QR code pem_data: Raw PEM string from QR code
Returns: Returns:
Properly formatted PEM string that cryptography library will accept Properly formatted PEM string that cryptography library will accept
""" """
import re import re
# Step 1: Normalize ALL line endings to \n # Step 1: Normalize ALL line endings to \n
pem_data = pem_data.replace('\r\n', '\n').replace('\r', '\n') pem_data = pem_data.replace('\r\n', '\n').replace('\r', '\n')
# Step 2: Remove leading/trailing whitespace # Step 2: Remove leading/trailing whitespace
pem_data = pem_data.strip() pem_data = pem_data.strip()
# Step 3: Remove any non-ASCII characters (QR artifacts) # Step 3: Remove any non-ASCII characters (QR artifacts)
pem_data = ''.join(char for char in pem_data if ord(char) < 128) pem_data = ''.join(char for char in pem_data if ord(char) < 128)
# Step 4: Extract header, content, and footer with flexible regex # Step 4: Extract header, content, and footer with flexible regex
# This handles variations like: # This handles variations like:
# - "PRIVATE KEY" vs "RSA PRIVATE KEY" # - "PRIVATE KEY" vs "RSA PRIVATE KEY"
@@ -118,51 +116,51 @@ def normalize_pem(pem_data: str) -> str:
# - Missing spaces # - Missing spaces
pattern = r'(-----BEGIN[^-]*-----)(.*?)(-----END[^-]*-----)' pattern = r'(-----BEGIN[^-]*-----)(.*?)(-----END[^-]*-----)'
match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE) match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE)
if not match: if not match:
# Fallback: try even more permissive pattern # Fallback: try even more permissive pattern
pattern = r'(-+BEGIN[^-]+-+)(.*?)(-+END[^-]+-+)' pattern = r'(-+BEGIN[^-]+-+)(.*?)(-+END[^-]+-+)'
match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE) match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE)
if not match: if not match:
# Last resort: return original if can't parse # Last resort: return original if can't parse
return pem_data return pem_data
header_raw = match.group(1).strip() header_raw = match.group(1).strip()
content_raw = match.group(2) content_raw = match.group(2)
footer_raw = match.group(3).strip() footer_raw = match.group(3).strip()
# Step 5: Normalize header and footer # Step 5: Normalize header and footer
# Standardize spacing and ensure proper format # Standardize spacing and ensure proper format
header = re.sub(r'\s+', ' ', header_raw) header = re.sub(r'\s+', ' ', header_raw)
footer = re.sub(r'\s+', ' ', footer_raw) footer = re.sub(r'\s+', ' ', footer_raw)
# Ensure exactly 5 dashes on each side # Ensure exactly 5 dashes on each side
header = re.sub(r'^-+', '-----', header) header = re.sub(r'^-+', '-----', header)
header = re.sub(r'-+$', '-----', header) header = re.sub(r'-+$', '-----', header)
footer = re.sub(r'^-+', '-----', footer) footer = re.sub(r'^-+', '-----', footer)
footer = re.sub(r'-+$', '-----', footer) footer = re.sub(r'-+$', '-----', footer)
# Step 6: Clean the base64 content THOROUGHLY # Step 6: Clean the base64 content THOROUGHLY
# Remove ALL whitespace: spaces, tabs, newlines # Remove ALL whitespace: spaces, tabs, newlines
# Keep only valid base64 characters: A-Z, a-z, 0-9, +, /, = # Keep only valid base64 characters: A-Z, a-z, 0-9, +, /, =
content_clean = ''.join( content_clean = ''.join(
char for char in content_raw char for char in content_raw
if char.isalnum() or char in '+/=' if char.isalnum() or char in '+/='
) )
# Double-check: remove any remaining invalid characters # Double-check: remove any remaining invalid characters
content_clean = re.sub(r'[^A-Za-z0-9+/=]', '', content_clean) content_clean = re.sub(r'[^A-Za-z0-9+/=]', '', content_clean)
# Step 7: Fix base64 padding # Step 7: Fix base64 padding
# Base64 strings must be divisible by 4 # Base64 strings must be divisible by 4
remainder = len(content_clean) % 4 remainder = len(content_clean) % 4
if remainder: if remainder:
content_clean += '=' * (4 - remainder) content_clean += '=' * (4 - remainder)
# Step 8: Split into 64-character lines (PEM standard) # Step 8: Split into 64-character lines (PEM standard)
lines = [content_clean[i:i+64] for i in range(0, len(content_clean), 64)] lines = [content_clean[i:i+64] for i in range(0, len(content_clean), 64)]
# Step 9: Reconstruct with EXACT PEM formatting # Step 9: Reconstruct with EXACT PEM formatting
# Format: header\ncontent_line1\ncontent_line2\n...\nfooter\n # Format: header\ncontent_line1\ncontent_line2\n...\nfooter\n
return header + '\n' + '\n'.join(lines) + '\n' + footer + '\n' return header + '\n' + '\n'.join(lines) + '\n' + footer + '\n'
@@ -176,10 +174,10 @@ def is_compressed(data: str) -> bool:
def auto_decompress(data: str) -> str: def auto_decompress(data: str) -> str:
""" """
Automatically decompress data if compressed, otherwise return as-is. Automatically decompress data if compressed, otherwise return as-is.
Args: Args:
data: Possibly compressed string data: Possibly compressed string
Returns: Returns:
Decompressed string Decompressed string
""" """
@@ -196,11 +194,11 @@ def get_compressed_size(data: str) -> int:
def can_fit_in_qr(data: str, compress: bool = False) -> bool: def can_fit_in_qr(data: str, compress: bool = False) -> bool:
""" """
Check if data can fit in a QR code. Check if data can fit in a QR code.
Args: Args:
data: String data data: String data
compress: Whether compression will be used compress: Whether compression will be used
Returns: Returns:
True if data fits True if data fits
""" """
@@ -223,39 +221,39 @@ def generate_qr_code(
) -> bytes: ) -> bytes:
""" """
Generate a QR code PNG from string data. Generate a QR code PNG from string data.
Args: Args:
data: String data to encode data: String data to encode
compress: Whether to compress data first compress: Whether to compress data first
error_correction: QR error correction level (default: auto) error_correction: QR error correction level (default: auto)
Returns: Returns:
PNG image bytes PNG image bytes
Raises: Raises:
RuntimeError: If qrcode library not available RuntimeError: If qrcode library not available
ValueError: If data too large for QR code ValueError: If data too large for QR code
""" """
if not HAS_QRCODE_WRITE: if not HAS_QRCODE_WRITE:
raise RuntimeError("qrcode library not installed. Run: pip install qrcode[pil]") raise RuntimeError("qrcode library not installed. Run: pip install qrcode[pil]")
qr_data = data qr_data = data
# Compress if requested # Compress if requested
if compress: if compress:
qr_data = compress_data(data) qr_data = compress_data(data)
# Check size # Check size
if len(qr_data.encode('utf-8')) > QR_MAX_BINARY: if len(qr_data.encode('utf-8')) > QR_MAX_BINARY:
raise ValueError( raise ValueError(
f"Data too large for QR code ({len(qr_data)} bytes). " f"Data too large for QR code ({len(qr_data)} bytes). "
f"Maximum: {QR_MAX_BINARY} bytes" f"Maximum: {QR_MAX_BINARY} bytes"
) )
# Use lower error correction for larger data # Use lower error correction for larger data
if error_correction is None: if error_correction is None:
error_correction = ERROR_CORRECT_L if len(qr_data) > 1000 else ERROR_CORRECT_M error_correction = ERROR_CORRECT_L if len(qr_data) > 1000 else ERROR_CORRECT_M
qr = qrcode.QRCode( qr = qrcode.QRCode(
version=None, version=None,
error_correction=error_correction, error_correction=error_correction,
@@ -264,25 +262,25 @@ def generate_qr_code(
) )
qr.add_data(qr_data) qr.add_data(qr_data)
qr.make(fit=True) qr.make(fit=True)
img = qr.make_image(fill_color="black", back_color="white") img = qr.make_image(fill_color="black", back_color="white")
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format='PNG') img.save(buf, format='PNG')
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
def read_qr_code(image_data: bytes) -> Optional[str]: def read_qr_code(image_data: bytes) -> str | None:
""" """
Read QR code from image data. Read QR code from image data.
Args: Args:
image_data: Image bytes (PNG, JPG, etc.) image_data: Image bytes (PNG, JPG, etc.)
Returns: Returns:
Decoded string, or None if no QR code found Decoded string, or None if no QR code found
Raises: Raises:
RuntimeError: If pyzbar library not available RuntimeError: If pyzbar library not available
""" """
@@ -291,35 +289,35 @@ def read_qr_code(image_data: bytes) -> Optional[str]:
"pyzbar library not installed. Run: pip install pyzbar\n" "pyzbar library not installed. Run: pip install pyzbar\n"
"Also requires system library: sudo apt-get install libzbar0" "Also requires system library: sudo apt-get install libzbar0"
) )
try: try:
img: Image.Image = Image.open(io.BytesIO(image_data)) img: Image.Image = Image.open(io.BytesIO(image_data))
# Convert to RGB if necessary (pyzbar works best with RGB/grayscale) # Convert to RGB if necessary (pyzbar works best with RGB/grayscale)
if img.mode not in ('RGB', 'L'): if img.mode not in ('RGB', 'L'):
img = img.convert('RGB') img = img.convert('RGB')
# Decode QR codes # Decode QR codes
decoded = pyzbar_decode(img, symbols=[ZBarSymbol.QRCODE]) decoded = pyzbar_decode(img, symbols=[ZBarSymbol.QRCODE])
if not decoded: if not decoded:
return None return None
# Return first QR code found # Return first QR code found
result: str = decoded[0].data.decode('utf-8') result: str = decoded[0].data.decode('utf-8')
return result return result
except Exception: except Exception:
return None return None
def read_qr_code_from_file(filepath: str) -> Optional[str]: def read_qr_code_from_file(filepath: str) -> str | None:
""" """
Read QR code from image file. Read QR code from image file.
Args: Args:
filepath: Path to image file filepath: Path to image file
Returns: Returns:
Decoded string, or None if no QR code found Decoded string, or None if no QR code found
""" """
@@ -327,25 +325,25 @@ def read_qr_code_from_file(filepath: str) -> Optional[str]:
return read_qr_code(f.read()) return read_qr_code(f.read())
def extract_key_from_qr(image_data: bytes) -> Optional[str]: def extract_key_from_qr(image_data: bytes) -> str | None:
""" """
Extract RSA key from QR code image, auto-decompressing if needed. Extract RSA key from QR code image, auto-decompressing if needed.
This function is more robust than the original, with better error handling This function is more robust than the original, with better error handling
and PEM normalization. and PEM normalization.
Args: Args:
image_data: Image bytes containing QR code image_data: Image bytes containing QR code
Returns: Returns:
PEM-encoded RSA key string, or None if not found/invalid PEM-encoded RSA key string, or None if not found/invalid
""" """
# Step 1: Read QR code # Step 1: Read QR code
qr_data = read_qr_code(image_data) qr_data = read_qr_code(image_data)
if not qr_data: if not qr_data:
return None return None
# Step 2: Auto-decompress if needed # Step 2: Auto-decompress if needed
try: try:
if is_compressed(qr_data): if is_compressed(qr_data):
@@ -355,11 +353,11 @@ def extract_key_from_qr(image_data: bytes) -> Optional[str]:
except Exception: except Exception:
# If decompression fails, try using data as-is # If decompression fails, try using data as-is
key_pem = qr_data key_pem = qr_data
# Step 3: Validate it looks like a PEM key # Step 3: Validate it looks like a PEM key
if '-----BEGIN' not in key_pem or '-----END' not in key_pem: if '-----BEGIN' not in key_pem or '-----END' not in key_pem:
return None return None
# Step 4: Aggressively normalize PEM format # Step 4: Aggressively normalize PEM format
# This is crucial - QR codes can introduce subtle formatting issues # This is crucial - QR codes can introduce subtle formatting issues
try: try:
@@ -367,21 +365,21 @@ def extract_key_from_qr(image_data: bytes) -> Optional[str]:
except Exception: except Exception:
# If normalization fails, return None rather than broken PEM # If normalization fails, return None rather than broken PEM
return None return None
# Step 5: Final validation - ensure it still looks like PEM # Step 5: Final validation - ensure it still looks like PEM
if '-----BEGIN' in key_pem and '-----END' in key_pem: if '-----BEGIN' in key_pem and '-----END' in key_pem:
return key_pem return key_pem
return None return None
def extract_key_from_qr_file(filepath: str) -> Optional[str]: def extract_key_from_qr_file(filepath: str) -> str | None:
""" """
Extract RSA key from QR code image file. Extract RSA key from QR code image file.
Args: Args:
filepath: Path to image file containing QR code filepath: Path to image file containing QR code
Returns: Returns:
PEM-encoded RSA key string, or None if not found/invalid PEM-encoded RSA key string, or None if not found/invalid
""" """
@@ -393,21 +391,21 @@ def detect_and_crop_qr(
image_data: bytes, image_data: bytes,
padding_percent: float = QR_CROP_PADDING_PERCENT, padding_percent: float = QR_CROP_PADDING_PERCENT,
min_padding_px: int = QR_CROP_MIN_PADDING_PX min_padding_px: int = QR_CROP_MIN_PADDING_PX
) -> Optional[bytes]: ) -> bytes | None:
""" """
Detect QR code in image and crop to it, handling rotation. Detect QR code in image and crop to it, handling rotation.
Uses the QR code's corner coordinates to compute an axis-aligned Uses the QR code's corner coordinates to compute an axis-aligned
bounding box, then adds padding to ensure rotated QR codes aren't clipped. bounding box, then adds padding to ensure rotated QR codes aren't clipped.
Args: Args:
image_data: Input image bytes (PNG, JPG, etc.) image_data: Input image bytes (PNG, JPG, etc.)
padding_percent: Padding as fraction of QR size (default 10%) padding_percent: Padding as fraction of QR size (default 10%)
min_padding_px: Minimum padding in pixels (default 10) min_padding_px: Minimum padding in pixels (default 10)
Returns: Returns:
Cropped PNG image bytes, or None if no QR code found Cropped PNG image bytes, or None if no QR code found
Raises: Raises:
RuntimeError: If pyzbar library not available RuntimeError: If pyzbar library not available
""" """
@@ -416,27 +414,27 @@ def detect_and_crop_qr(
"pyzbar library not installed. Run: pip install pyzbar\n" "pyzbar library not installed. Run: pip install pyzbar\n"
"Also requires system library: sudo apt-get install libzbar0" "Also requires system library: sudo apt-get install libzbar0"
) )
try: try:
img: Image.Image = Image.open(io.BytesIO(image_data)) img: Image.Image = Image.open(io.BytesIO(image_data))
original_mode = img.mode original_mode = img.mode
# Convert for pyzbar detection # Convert for pyzbar detection
if img.mode not in ('RGB', 'L'): if img.mode not in ('RGB', 'L'):
detect_img = img.convert('RGB') detect_img = img.convert('RGB')
else: else:
detect_img = img detect_img = img
# Decode QR codes to get corner positions # Decode QR codes to get corner positions
decoded = pyzbar_decode(detect_img, symbols=[ZBarSymbol.QRCODE]) decoded = pyzbar_decode(detect_img, symbols=[ZBarSymbol.QRCODE])
if not decoded: if not decoded:
return None return None
# Get the polygon corners of the first QR code # Get the polygon corners of the first QR code
# pyzbar returns a Polygon with Point objects (x, y attributes) # pyzbar returns a Polygon with Point objects (x, y attributes)
polygon = decoded[0].polygon polygon = decoded[0].polygon
if len(polygon) < 4: if len(polygon) < 4:
# Fallback to rect if polygon not available # Fallback to rect if polygon not available
rect = decoded[0].rect rect = decoded[0].rect
@@ -448,25 +446,25 @@ def detect_and_crop_qr(
ys = [p.y for p in polygon] ys = [p.y for p in polygon]
min_x, max_x = min(xs), max(xs) min_x, max_x = min(xs), max(xs)
min_y, max_y = min(ys), max(ys) min_y, max_y = min(ys), max(ys)
# Calculate QR dimensions and padding # Calculate QR dimensions and padding
qr_width = max_x - min_x qr_width = max_x - min_x
qr_height = max_y - min_y qr_height = max_y - min_y
# Use larger dimension for padding calculation (handles rotation) # Use larger dimension for padding calculation (handles rotation)
qr_size = max(qr_width, qr_height) qr_size = max(qr_width, qr_height)
padding = max(int(qr_size * padding_percent), min_padding_px) padding = max(int(qr_size * padding_percent), min_padding_px)
# Calculate crop box with padding, clamped to image bounds # Calculate crop box with padding, clamped to image bounds
img_width, img_height = img.size img_width, img_height = img.size
crop_left = max(0, min_x - padding) crop_left = max(0, min_x - padding)
crop_top = max(0, min_y - padding) crop_top = max(0, min_y - padding)
crop_right = min(img_width, max_x + padding) crop_right = min(img_width, max_x + padding)
crop_bottom = min(img_height, max_y + padding) crop_bottom = min(img_height, max_y + padding)
# Crop the original image (preserves original mode/quality) # Crop the original image (preserves original mode/quality)
cropped = img.crop((crop_left, crop_top, crop_right, crop_bottom)) cropped = img.crop((crop_left, crop_top, crop_right, crop_bottom))
# Convert to PNG bytes # Convert to PNG bytes
buf = io.BytesIO() buf = io.BytesIO()
# Preserve transparency if present # Preserve transparency if present
@@ -476,7 +474,7 @@ def detect_and_crop_qr(
cropped.save(buf, format='PNG') cropped.save(buf, format='PNG')
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
except Exception as e: except Exception as e:
# Log for debugging but return None for clean API # Log for debugging but return None for clean API
import sys import sys
@@ -488,15 +486,15 @@ def detect_and_crop_qr_file(
filepath: str, filepath: str,
padding_percent: float = QR_CROP_PADDING_PERCENT, padding_percent: float = QR_CROP_PADDING_PERCENT,
min_padding_px: int = QR_CROP_MIN_PADDING_PX min_padding_px: int = QR_CROP_MIN_PADDING_PX
) -> Optional[bytes]: ) -> bytes | None:
""" """
Detect QR code in image file and crop to it. Detect QR code in image file and crop to it.
Args: Args:
filepath: Path to image file filepath: Path to image file
padding_percent: Padding as fraction of QR size (default 10%) padding_percent: Padding as fraction of QR size (default 10%)
min_padding_px: Minimum padding in pixels (default 10) min_padding_px: Minimum padding in pixels (default 10)
Returns: Returns:
Cropped PNG image bytes, or None if no QR code found Cropped PNG image bytes, or None if no QR code found
""" """

View File

@@ -20,22 +20,24 @@ Changes in v3.2.0:
import io import io
import struct import struct
from typing import Optional, Tuple, List, Union from typing import TYPE_CHECKING, Union
from PIL import Image
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from PIL import Image
if TYPE_CHECKING:
from .dct_steganography import DCTEmbedStats
from .models import EmbedStats, FilePayload
from .exceptions import CapacityError, ExtractionError, EmbeddingError
from .debug import debug
from .constants import ( from .constants import (
EMBED_MODE_LSB,
EMBED_MODE_DCT,
EMBED_MODE_AUTO, EMBED_MODE_AUTO,
EMBED_MODE_DCT,
EMBED_MODE_LSB,
VALID_EMBED_MODES, VALID_EMBED_MODES,
) )
from .debug import debug
from .exceptions import CapacityError, EmbeddingError
from .models import EmbedStats, FilePayload
# Lossless formats that preserve LSB data # Lossless formats that preserve LSB data
LOSSLESS_FORMATS = {'PNG', 'BMP', 'TIFF'} LOSSLESS_FORMATS = {'PNG', 'BMP', 'TIFF'}
@@ -103,10 +105,10 @@ def _get_dct_module():
def has_dct_support() -> bool: def has_dct_support() -> bool:
""" """
Check if DCT steganography mode is available. Check if DCT steganography mode is available.
Returns: Returns:
True if scipy is installed and DCT functions work True if scipy is installed and DCT functions work
Example: Example:
>>> if has_dct_support(): >>> if has_dct_support():
... result = encode(..., embed_mode='dct') ... result = encode(..., embed_mode='dct')
@@ -122,26 +124,26 @@ def has_dct_support() -> bool:
# FORMAT UTILITIES # FORMAT UTILITIES
# ============================================================================= # =============================================================================
def get_output_format(input_format: Optional[str]) -> Tuple[str, str]: def get_output_format(input_format: str | None) -> tuple[str, str]:
""" """
Determine the output format based on input format. Determine the output format based on input format.
Args: Args:
input_format: PIL format string of input image (e.g., 'JPEG', 'PNG') input_format: PIL format string of input image (e.g., 'JPEG', 'PNG')
Returns: Returns:
Tuple of (PIL format string, file extension) for output Tuple of (PIL format string, file extension) for output
Falls back to PNG for lossy or unknown formats. Falls back to PNG for lossy or unknown formats.
""" """
debug.validate(input_format is None or isinstance(input_format, str), debug.validate(input_format is None or isinstance(input_format, str),
"Input format must be string or None") "Input format must be string or None")
if input_format and input_format.upper() in LOSSLESS_FORMATS: if input_format and input_format.upper() in LOSSLESS_FORMATS:
fmt = input_format.upper() fmt = input_format.upper()
ext = FORMAT_TO_EXT.get(fmt, 'png') ext = FORMAT_TO_EXT.get(fmt, 'png')
debug.print(f"Using lossless format: {fmt} -> .{ext}") debug.print(f"Using lossless format: {fmt} -> .{ext}")
return fmt, ext return fmt, ext
debug.print(f"Input format {input_format} is lossy or unknown, defaulting to PNG") debug.print(f"Input format {input_format} is lossy or unknown, defaulting to PNG")
return 'PNG', 'png' return 'PNG', 'png'
@@ -151,20 +153,20 @@ def get_output_format(input_format: Optional[str]) -> Tuple[str, str]:
# ============================================================================= # =============================================================================
def will_fit( def will_fit(
payload: Union[str, bytes, FilePayload, int], payload: str | bytes | FilePayload | int,
carrier_image: bytes, carrier_image: bytes,
bits_per_channel: int = 1, bits_per_channel: int = 1,
include_compression_estimate: bool = True, include_compression_estimate: bool = True,
) -> dict: ) -> dict:
""" """
Check if a payload will fit in a carrier image (LSB mode). Check if a payload will fit in a carrier image (LSB mode).
Args: Args:
payload: Message string, raw bytes, FilePayload, or size in bytes payload: Message string, raw bytes, FilePayload, or size in bytes
carrier_image: Carrier image bytes carrier_image: Carrier image bytes
bits_per_channel: Bits to use per color channel (1-2) bits_per_channel: Bits to use per color channel (1-2)
include_compression_estimate: Estimate compressed size include_compression_estimate: Estimate compressed size
Returns: Returns:
Dict with fits, capacity, usage info Dict with fits, capacity, usage info
""" """
@@ -183,15 +185,15 @@ def will_fit(
else: else:
payload_data = payload payload_data = payload
payload_size = len(payload) payload_size = len(payload)
capacity = calculate_capacity(carrier_image, bits_per_channel) capacity = calculate_capacity(carrier_image, bits_per_channel)
# Estimate encrypted size with padding # Estimate encrypted size with padding
# Padding adds 64-319 bytes, rounded up to 256-byte boundary # Padding adds 64-319 bytes, rounded up to 256-byte boundary
# Average case: ~190 bytes padding # Average case: ~190 bytes padding
estimated_padding = 190 estimated_padding = 190
estimated_encrypted_size = payload_size + estimated_padding + ENCRYPTION_OVERHEAD estimated_encrypted_size = payload_size + estimated_padding + ENCRYPTION_OVERHEAD
compressed_estimate = None compressed_estimate = None
if include_compression_estimate and payload_data is not None and len(payload_data) >= 64: if include_compression_estimate and payload_data is not None and len(payload_data) >= 64:
try: try:
@@ -203,11 +205,11 @@ def will_fit(
estimated_encrypted_size = compressed_size + estimated_padding + ENCRYPTION_OVERHEAD estimated_encrypted_size = compressed_size + estimated_padding + ENCRYPTION_OVERHEAD
except Exception: except Exception:
pass pass
headroom = capacity - estimated_encrypted_size headroom = capacity - estimated_encrypted_size
fits = headroom >= 0 fits = headroom >= 0
usage_percent = (estimated_encrypted_size / capacity * 100) if capacity > 0 else 100.0 usage_percent = (estimated_encrypted_size / capacity * 100) if capacity > 0 else 100.0
return { return {
'fits': fits, 'fits': fits,
'payload_size': payload_size, 'payload_size': payload_size,
@@ -223,23 +225,23 @@ def will_fit(
def calculate_capacity(image_data: bytes, bits_per_channel: int = 1) -> int: def calculate_capacity(image_data: bytes, bits_per_channel: int = 1) -> int:
""" """
Calculate the maximum message capacity of an image (LSB mode). Calculate the maximum message capacity of an image (LSB mode).
Args: Args:
image_data: Image bytes image_data: Image bytes
bits_per_channel: Bits to use per color channel bits_per_channel: Bits to use per color channel
Returns: Returns:
Maximum bytes that can be embedded (minus overhead) Maximum bytes that can be embedded (minus overhead)
""" """
debug.validate(bits_per_channel in (1, 2), debug.validate(bits_per_channel in (1, 2),
f"bits_per_channel must be 1 or 2, got {bits_per_channel}") f"bits_per_channel must be 1 or 2, got {bits_per_channel}")
img_file = Image.open(io.BytesIO(image_data)) img_file = Image.open(io.BytesIO(image_data))
try: try:
num_pixels = img_file.size[0] * img_file.size[1] num_pixels = img_file.size[0] * img_file.size[1]
bits_per_pixel = 3 * bits_per_channel bits_per_pixel = 3 * bits_per_channel
max_bytes = (num_pixels * bits_per_pixel) // 8 max_bytes = (num_pixels * bits_per_pixel) // 8
capacity = max(0, max_bytes - ENCRYPTION_OVERHEAD) capacity = max(0, max_bytes - ENCRYPTION_OVERHEAD)
debug.print(f"LSB capacity: {capacity} bytes at {bits_per_channel} bit(s)/channel") debug.print(f"LSB capacity: {capacity} bytes at {bits_per_channel} bit(s)/channel")
return capacity return capacity
@@ -248,28 +250,28 @@ def calculate_capacity(image_data: bytes, bits_per_channel: int = 1) -> int:
def calculate_capacity_by_mode( def calculate_capacity_by_mode(
image_data: bytes, image_data: bytes,
embed_mode: str = EMBED_MODE_LSB, embed_mode: str = EMBED_MODE_LSB,
bits_per_channel: int = 1, bits_per_channel: int = 1,
) -> dict: ) -> dict:
""" """
Calculate capacity for specified embedding mode. Calculate capacity for specified embedding mode.
Args: Args:
image_data: Carrier image bytes image_data: Carrier image bytes
embed_mode: 'lsb' or 'dct' embed_mode: 'lsb' or 'dct'
bits_per_channel: Bits per channel for LSB mode bits_per_channel: Bits per channel for LSB mode
Returns: Returns:
Dict with capacity information Dict with capacity information
""" """
if embed_mode == EMBED_MODE_DCT: if embed_mode == EMBED_MODE_DCT:
if not has_dct_support(): if not has_dct_support():
raise ImportError("scipy required for DCT mode. Install: pip install scipy") raise ImportError("scipy required for DCT mode. Install: pip install scipy")
dct_mod = _get_dct_module() dct_mod = _get_dct_module()
dct_info = dct_mod.calculate_dct_capacity(image_data) dct_info = dct_mod.calculate_dct_capacity(image_data)
return { return {
'mode': EMBED_MODE_DCT, 'mode': EMBED_MODE_DCT,
'capacity_bytes': dct_info.usable_capacity_bytes, 'capacity_bytes': dct_info.usable_capacity_bytes,
@@ -285,7 +287,7 @@ def calculate_capacity_by_mode(
width, height = img.size width, height = img.size
finally: finally:
img.close() img.close()
return { return {
'mode': EMBED_MODE_LSB, 'mode': EMBED_MODE_LSB,
'capacity_bytes': capacity, 'capacity_bytes': capacity,
@@ -297,27 +299,27 @@ def calculate_capacity_by_mode(
def will_fit_by_mode( def will_fit_by_mode(
payload: Union[str, bytes, FilePayload, int], payload: str | bytes | FilePayload | int,
carrier_image: bytes, carrier_image: bytes,
embed_mode: str = EMBED_MODE_LSB, embed_mode: str = EMBED_MODE_LSB,
bits_per_channel: int = 1, bits_per_channel: int = 1,
) -> dict: ) -> dict:
""" """
Check if payload fits in specified mode. Check if payload fits in specified mode.
Args: Args:
payload: Message, bytes, FilePayload, or size in bytes payload: Message, bytes, FilePayload, or size in bytes
carrier_image: Carrier image bytes carrier_image: Carrier image bytes
embed_mode: 'lsb' or 'dct' embed_mode: 'lsb' or 'dct'
bits_per_channel: For LSB mode bits_per_channel: For LSB mode
Returns: Returns:
Dict with fits, capacity, usage info Dict with fits, capacity, usage info
""" """
if embed_mode == EMBED_MODE_DCT: if embed_mode == EMBED_MODE_DCT:
if not has_dct_support(): if not has_dct_support():
return {'fits': False, 'error': 'scipy not available', 'mode': EMBED_MODE_DCT} return {'fits': False, 'error': 'scipy not available', 'mode': EMBED_MODE_DCT}
if isinstance(payload, int): if isinstance(payload, int):
payload_size = payload payload_size = payload
elif isinstance(payload, str): elif isinstance(payload, str):
@@ -326,16 +328,16 @@ def will_fit_by_mode(
payload_size = len(payload.data) payload_size = len(payload.data)
else: else:
payload_size = len(payload) payload_size = len(payload)
estimated_size = payload_size + ENCRYPTION_OVERHEAD + 190 # padding estimate estimated_size = payload_size + ENCRYPTION_OVERHEAD + 190 # padding estimate
dct_mod = _get_dct_module() dct_mod = _get_dct_module()
fits = dct_mod.will_fit_dct(estimated_size, carrier_image) fits = dct_mod.will_fit_dct(estimated_size, carrier_image)
capacity_info = dct_mod.calculate_dct_capacity(carrier_image) capacity_info = dct_mod.calculate_dct_capacity(carrier_image)
capacity = capacity_info.usable_capacity_bytes capacity = capacity_info.usable_capacity_bytes
usage_percent = (estimated_size / capacity * 100) if capacity > 0 else 100.0 usage_percent = (estimated_size / capacity * 100) if capacity > 0 else 100.0
return { return {
'fits': fits, 'fits': fits,
'payload_size': payload_size, 'payload_size': payload_size,
@@ -351,7 +353,7 @@ def will_fit_by_mode(
def get_available_modes() -> dict: def get_available_modes() -> dict:
""" """
Get available embedding modes and their status. Get available embedding modes and their status.
Returns: Returns:
Dict mapping mode name to availability info Dict mapping mode name to availability info
""" """
@@ -375,10 +377,10 @@ def get_available_modes() -> dict:
def compare_modes(image_data: bytes) -> dict: def compare_modes(image_data: bytes) -> dict:
""" """
Compare embedding modes for a carrier image. Compare embedding modes for a carrier image.
Args: Args:
image_data: Carrier image bytes image_data: Carrier image bytes
Returns: Returns:
Dict with comparison of LSB vs DCT modes Dict with comparison of LSB vs DCT modes
""" """
@@ -387,9 +389,9 @@ def compare_modes(image_data: bytes) -> dict:
width, height = img.size width, height = img.size
finally: finally:
img.close() img.close()
lsb_bytes = calculate_capacity(image_data, 1) lsb_bytes = calculate_capacity(image_data, 1)
if has_dct_support(): if has_dct_support():
dct_mod = _get_dct_module() dct_mod = _get_dct_module()
dct_info = dct_mod.calculate_dct_capacity(image_data) dct_info = dct_mod.calculate_dct_capacity(image_data)
@@ -399,7 +401,7 @@ def compare_modes(image_data: bytes) -> dict:
safe_blocks = (height // 8) * (width // 8) safe_blocks = (height // 8) * (width // 8)
dct_bytes = (safe_blocks * 16) // 8 # Estimated dct_bytes = (safe_blocks * 16) // 8 # Estimated
dct_available = False dct_available = False
return { return {
'width': width, 'width': width,
'height': height, 'height': height,
@@ -424,62 +426,62 @@ def compare_modes(image_data: bytes) -> dict:
# ============================================================================= # =============================================================================
@debug.time @debug.time
def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> List[int]: def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list[int]:
""" """
Generate pseudo-random pixel indices for embedding. Generate pseudo-random pixel indices for embedding.
Uses ChaCha20 as a CSPRNG seeded by the key to deterministically Uses ChaCha20 as a CSPRNG seeded by the key to deterministically
select which pixels will hold hidden data. select which pixels will hold hidden data.
""" """
debug.validate(len(key) == 32, f"Pixel key must be 32 bytes, got {len(key)}") debug.validate(len(key) == 32, f"Pixel key must be 32 bytes, got {len(key)}")
debug.validate(num_pixels > 0, f"Number of pixels must be positive, got {num_pixels}") debug.validate(num_pixels > 0, f"Number of pixels must be positive, got {num_pixels}")
debug.validate(num_needed > 0, f"Number needed must be positive, got {num_needed}") debug.validate(num_needed > 0, f"Number needed must be positive, got {num_needed}")
debug.validate(num_needed <= num_pixels, debug.validate(num_needed <= num_pixels,
f"Cannot select {num_needed} pixels from {num_pixels} available") f"Cannot select {num_needed} pixels from {num_pixels} available")
debug.print(f"Generating {num_needed} pixel indices from {num_pixels} total pixels") debug.print(f"Generating {num_needed} pixel indices from {num_pixels} total pixels")
if num_needed >= num_pixels // 2: if num_needed >= num_pixels // 2:
debug.print(f"Using full shuffle (needed {num_needed}/{num_pixels} pixels)") debug.print(f"Using full shuffle (needed {num_needed}/{num_pixels} pixels)")
nonce = b'\x00' * 16 nonce = b'\x00' * 16
cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend()) cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend())
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
indices = list(range(num_pixels)) indices = list(range(num_pixels))
random_bytes = encryptor.update(b'\x00' * (num_pixels * 4)) random_bytes = encryptor.update(b'\x00' * (num_pixels * 4))
for i in range(num_pixels - 1, 0, -1): for i in range(num_pixels - 1, 0, -1):
j_bytes = random_bytes[(num_pixels - 1 - i) * 4:(num_pixels - i) * 4] j_bytes = random_bytes[(num_pixels - 1 - i) * 4:(num_pixels - i) * 4]
j = int.from_bytes(j_bytes, 'big') % (i + 1) j = int.from_bytes(j_bytes, 'big') % (i + 1)
indices[i], indices[j] = indices[j], indices[i] indices[i], indices[j] = indices[j], indices[i]
selected = indices[:num_needed] selected = indices[:num_needed]
debug.print(f"Generated {len(selected)} indices via shuffle") debug.print(f"Generated {len(selected)} indices via shuffle")
return selected return selected
debug.print(f"Using optimized selection (needed {num_needed}/{num_pixels} pixels)") debug.print(f"Using optimized selection (needed {num_needed}/{num_pixels} pixels)")
selected = [] selected = []
used = set() used = set()
nonce = b'\x00' * 16 nonce = b'\x00' * 16
cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend()) cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend())
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
bytes_needed = (num_needed * 2) * 4 bytes_needed = (num_needed * 2) * 4
random_bytes = encryptor.update(b'\x00' * bytes_needed) random_bytes = encryptor.update(b'\x00' * bytes_needed)
byte_offset = 0 byte_offset = 0
collisions = 0 collisions = 0
while len(selected) < num_needed and byte_offset < len(random_bytes) - 4: while len(selected) < num_needed and byte_offset < len(random_bytes) - 4:
idx = int.from_bytes(random_bytes[byte_offset:byte_offset + 4], 'big') % num_pixels idx = int.from_bytes(random_bytes[byte_offset:byte_offset + 4], 'big') % num_pixels
byte_offset += 4 byte_offset += 4
if idx not in used: if idx not in used:
used.add(idx) used.add(idx)
selected.append(idx) selected.append(idx)
else: else:
collisions += 1 collisions += 1
if len(selected) < num_needed: if len(selected) < num_needed:
debug.print(f"Need {num_needed - len(selected)} more indices, generating...") debug.print(f"Need {num_needed - len(selected)} more indices, generating...")
extra_needed = num_needed - len(selected) extra_needed = num_needed - len(selected)
@@ -491,7 +493,7 @@ def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> List
selected.append(idx) selected.append(idx)
if len(selected) == num_needed: if len(selected) == num_needed:
break break
debug.print(f"Generated {len(selected)} indices with {collisions} collisions") debug.print(f"Generated {len(selected)} indices with {collisions} collisions")
debug.validate(len(selected) == num_needed, debug.validate(len(selected) == num_needed,
f"Failed to generate enough indices: {len(selected)}/{num_needed}") f"Failed to generate enough indices: {len(selected)}/{num_needed}")
@@ -508,14 +510,14 @@ def embed_in_image(
image_data: bytes, image_data: bytes,
pixel_key: bytes, pixel_key: bytes,
bits_per_channel: int = 1, bits_per_channel: int = 1,
output_format: Optional[str] = None, output_format: str | None = None,
embed_mode: str = EMBED_MODE_LSB, embed_mode: str = EMBED_MODE_LSB,
dct_output_format: str = DCT_OUTPUT_PNG, dct_output_format: str = DCT_OUTPUT_PNG,
dct_color_mode: str = 'grayscale', dct_color_mode: str = 'grayscale',
) -> Tuple[bytes, Union[EmbedStats, 'DCTEmbedStats'], str]: ) -> tuple[bytes, Union[EmbedStats, 'DCTEmbedStats'], str]:
""" """
Embed data into an image using specified mode. Embed data into an image using specified mode.
Args: Args:
data: Data to embed (encrypted payload) data: Data to embed (encrypted payload)
image_data: Carrier image bytes image_data: Carrier image bytes
@@ -525,19 +527,19 @@ def embed_in_image(
embed_mode: 'lsb' (default) or 'dct' embed_mode: 'lsb' (default) or 'dct'
dct_output_format: For DCT mode - 'png' (lossless) or 'jpeg' (smaller) dct_output_format: For DCT mode - 'png' (lossless) or 'jpeg' (smaller)
dct_color_mode: For DCT mode - 'grayscale' (default) or 'color' (preserves colors) dct_color_mode: For DCT mode - 'grayscale' (default) or 'color' (preserves colors)
Returns: Returns:
Tuple of (stego image bytes, stats, file extension) Tuple of (stego image bytes, stats, file extension)
Raises: Raises:
CapacityError: If data won't fit CapacityError: If data won't fit
EmbeddingError: If embedding fails EmbeddingError: If embedding fails
ImportError: If DCT mode requested but scipy unavailable ImportError: If DCT mode requested but scipy unavailable
""" """
debug.print(f"embed_in_image: mode={embed_mode}, data={len(data)} bytes") debug.print(f"embed_in_image: mode={embed_mode}, data={len(data)} bytes")
debug.validate(embed_mode in VALID_EMBED_MODES, debug.validate(embed_mode in VALID_EMBED_MODES,
f"Invalid embed_mode: {embed_mode}. Use 'lsb' or 'dct'") f"Invalid embed_mode: {embed_mode}. Use 'lsb' or 'dct'")
# DCT MODE # DCT MODE
if embed_mode == EMBED_MODE_DCT: if embed_mode == EMBED_MODE_DCT:
if not has_dct_support(): if not has_dct_support():
@@ -545,38 +547,38 @@ def embed_in_image(
"scipy is required for DCT embedding mode. " "scipy is required for DCT embedding mode. "
"Install with: pip install scipy" "Install with: pip install scipy"
) )
# Validate DCT output format # Validate DCT output format
if dct_output_format not in (DCT_OUTPUT_PNG, DCT_OUTPUT_JPEG): if dct_output_format not in (DCT_OUTPUT_PNG, DCT_OUTPUT_JPEG):
debug.print(f"Invalid dct_output_format '{dct_output_format}', defaulting to PNG") debug.print(f"Invalid dct_output_format '{dct_output_format}', defaulting to PNG")
dct_output_format = DCT_OUTPUT_PNG dct_output_format = DCT_OUTPUT_PNG
# Validate DCT color mode (v3.0.1) # Validate DCT color mode (v3.0.1)
if dct_color_mode not in ('grayscale', 'color'): if dct_color_mode not in ('grayscale', 'color'):
debug.print(f"Invalid dct_color_mode '{dct_color_mode}', defaulting to grayscale") debug.print(f"Invalid dct_color_mode '{dct_color_mode}', defaulting to grayscale")
dct_color_mode = 'grayscale' dct_color_mode = 'grayscale'
dct_mod = _get_dct_module() dct_mod = _get_dct_module()
# Pass output_format and color_mode to DCT module (v3.0.1) # Pass output_format and color_mode to DCT module (v3.0.1)
stego_bytes, dct_stats = dct_mod.embed_in_dct( stego_bytes, dct_stats = dct_mod.embed_in_dct(
data, data,
image_data, image_data,
pixel_key, pixel_key,
output_format=dct_output_format, output_format=dct_output_format,
color_mode=dct_color_mode, color_mode=dct_color_mode,
) )
# Determine extension based on output format # Determine extension based on output format
if dct_output_format == DCT_OUTPUT_JPEG: if dct_output_format == DCT_OUTPUT_JPEG:
ext = 'jpg' ext = 'jpg'
else: else:
ext = 'png' ext = 'png'
debug.print(f"DCT embedding complete: {dct_output_format.upper()} output, " debug.print(f"DCT embedding complete: {dct_output_format.upper()} output, "
f"color_mode={dct_color_mode}, ext={ext}") f"color_mode={dct_color_mode}, ext={ext}")
return stego_bytes, dct_stats, ext return stego_bytes, dct_stats, ext
# LSB MODE # LSB MODE
return _embed_lsb(data, image_data, pixel_key, bits_per_channel, output_format) return _embed_lsb(data, image_data, pixel_key, bits_per_channel, output_format)
@@ -586,75 +588,75 @@ def _embed_lsb(
image_data: bytes, image_data: bytes,
pixel_key: bytes, pixel_key: bytes,
bits_per_channel: int = 1, bits_per_channel: int = 1,
output_format: Optional[str] = None, output_format: str | None = None,
) -> Tuple[bytes, EmbedStats, str]: ) -> tuple[bytes, EmbedStats, str]:
""" """
Embed data using LSB steganography (internal implementation). Embed data using LSB steganography (internal implementation).
""" """
debug.print(f"LSB embedding {len(data)} bytes into image") debug.print(f"LSB embedding {len(data)} bytes into image")
debug.data(pixel_key, "Pixel key for embedding") debug.data(pixel_key, "Pixel key for embedding")
debug.validate(bits_per_channel in (1, 2), debug.validate(bits_per_channel in (1, 2),
f"bits_per_channel must be 1 or 2, got {bits_per_channel}") f"bits_per_channel must be 1 or 2, got {bits_per_channel}")
debug.validate(len(pixel_key) == 32, debug.validate(len(pixel_key) == 32,
f"Pixel key must be 32 bytes, got {len(pixel_key)}") f"Pixel key must be 32 bytes, got {len(pixel_key)}")
img_file = None img_file = None
img = None img = None
stego_img = None stego_img = None
try: try:
img_file = Image.open(io.BytesIO(image_data)) img_file = Image.open(io.BytesIO(image_data))
input_format = img_file.format input_format = img_file.format
debug.print(f"Carrier image: {img_file.size[0]}x{img_file.size[1]}, format: {input_format}") debug.print(f"Carrier image: {img_file.size[0]}x{img_file.size[1]}, format: {input_format}")
img = img_file.convert('RGB') if img_file.mode != 'RGB' else img_file.copy() img = img_file.convert('RGB') if img_file.mode != 'RGB' else img_file.copy()
if img_file.mode != 'RGB': if img_file.mode != 'RGB':
debug.print(f"Converting image from {img_file.mode} to RGB") debug.print(f"Converting image from {img_file.mode} to RGB")
pixels = list(img.getdata()) pixels = list(img.getdata())
num_pixels = len(pixels) num_pixels = len(pixels)
bits_per_pixel = 3 * bits_per_channel bits_per_pixel = 3 * bits_per_channel
max_bytes = (num_pixels * bits_per_pixel) // 8 max_bytes = (num_pixels * bits_per_pixel) // 8
debug.print(f"Image capacity: {max_bytes} bytes at {bits_per_channel} bit(s)/channel") debug.print(f"Image capacity: {max_bytes} bytes at {bits_per_channel} bit(s)/channel")
data_with_len = struct.pack('>I', len(data)) + data data_with_len = struct.pack('>I', len(data)) + data
if len(data_with_len) > max_bytes: if len(data_with_len) > max_bytes:
debug.print(f"Capacity error: need {len(data_with_len)}, have {max_bytes}") debug.print(f"Capacity error: need {len(data_with_len)}, have {max_bytes}")
raise CapacityError(len(data_with_len), max_bytes) raise CapacityError(len(data_with_len), max_bytes)
debug.print(f"Total data to embed: {len(data_with_len)} bytes " debug.print(f"Total data to embed: {len(data_with_len)} bytes "
f"({len(data_with_len)/max_bytes*100:.1f}% of capacity)") f"({len(data_with_len)/max_bytes*100:.1f}% of capacity)")
binary_data = ''.join(format(b, '08b') for b in data_with_len) binary_data = ''.join(format(b, '08b') for b in data_with_len)
pixels_needed = (len(binary_data) + bits_per_pixel - 1) // bits_per_pixel pixels_needed = (len(binary_data) + bits_per_pixel - 1) // bits_per_pixel
debug.print(f"Need {pixels_needed} pixels to embed {len(binary_data)} bits") debug.print(f"Need {pixels_needed} pixels to embed {len(binary_data)} bits")
selected_indices = generate_pixel_indices(pixel_key, num_pixels, pixels_needed) selected_indices = generate_pixel_indices(pixel_key, num_pixels, pixels_needed)
new_pixels = list(pixels) new_pixels = list(pixels)
clear_mask = 0xFF ^ ((1 << bits_per_channel) - 1) clear_mask = 0xFF ^ ((1 << bits_per_channel) - 1)
bit_idx = 0 bit_idx = 0
modified_pixels = 0 modified_pixels = 0
for pixel_idx in selected_indices: for pixel_idx in selected_indices:
if bit_idx >= len(binary_data): if bit_idx >= len(binary_data):
break break
r, g, b = new_pixels[pixel_idx] r, g, b = new_pixels[pixel_idx]
modified = False modified = False
for channel_idx, channel_val in enumerate([r, g, b]): for channel_idx, channel_val in enumerate([r, g, b]):
if bit_idx >= len(binary_data): if bit_idx >= len(binary_data):
break break
bits = binary_data[bit_idx:bit_idx + bits_per_channel].ljust(bits_per_channel, '0') bits = binary_data[bit_idx:bit_idx + bits_per_channel].ljust(bits_per_channel, '0')
new_val = (channel_val & clear_mask) | int(bits, 2) new_val = (channel_val & clear_mask) | int(bits, 2)
if channel_val != new_val: if channel_val != new_val:
modified = True modified = True
if channel_idx == 0: if channel_idx == 0:
@@ -663,18 +665,18 @@ def _embed_lsb(
g = new_val g = new_val
else: else:
b = new_val b = new_val
bit_idx += bits_per_channel bit_idx += bits_per_channel
if modified: if modified:
new_pixels[pixel_idx] = (r, g, b) new_pixels[pixel_idx] = (r, g, b)
modified_pixels += 1 modified_pixels += 1
debug.print(f"Modified {modified_pixels} pixels (out of {len(selected_indices)} selected)") debug.print(f"Modified {modified_pixels} pixels (out of {len(selected_indices)} selected)")
stego_img = Image.new('RGB', img.size) stego_img = Image.new('RGB', img.size)
stego_img.putdata(new_pixels) stego_img.putdata(new_pixels)
if output_format: if output_format:
out_fmt = output_format.upper() out_fmt = output_format.upper()
out_ext = FORMAT_TO_EXT.get(out_fmt, 'png') out_ext = FORMAT_TO_EXT.get(out_fmt, 'png')
@@ -682,21 +684,21 @@ def _embed_lsb(
else: else:
out_fmt, out_ext = get_output_format(input_format) out_fmt, out_ext = get_output_format(input_format)
debug.print(f"Auto-selected output format: {out_fmt}") debug.print(f"Auto-selected output format: {out_fmt}")
output = io.BytesIO() output = io.BytesIO()
stego_img.save(output, out_fmt) stego_img.save(output, out_fmt)
output.seek(0) output.seek(0)
stats = EmbedStats( stats = EmbedStats(
pixels_modified=modified_pixels, pixels_modified=modified_pixels,
total_pixels=num_pixels, total_pixels=num_pixels,
capacity_used=len(data_with_len) / max_bytes, capacity_used=len(data_with_len) / max_bytes,
bytes_embedded=len(data_with_len) bytes_embedded=len(data_with_len)
) )
debug.print(f"LSB embedding complete: {out_fmt} image, {len(output.getvalue())} bytes") debug.print(f"LSB embedding complete: {out_fmt} image, {len(output.getvalue())} bytes")
return output.getvalue(), stats, out_ext return output.getvalue(), stats, out_ext
except CapacityError: except CapacityError:
raise raise
except Exception as e: except Exception as e:
@@ -722,50 +724,50 @@ def extract_from_image(
pixel_key: bytes, pixel_key: bytes,
bits_per_channel: int = 1, bits_per_channel: int = 1,
embed_mode: str = EMBED_MODE_AUTO, embed_mode: str = EMBED_MODE_AUTO,
) -> Optional[bytes]: ) -> bytes | None:
""" """
Extract hidden data from a stego image. Extract hidden data from a stego image.
Args: Args:
image_data: Stego image bytes image_data: Stego image bytes
pixel_key: Key for pixel/coefficient selection (must match encoding) pixel_key: Key for pixel/coefficient selection (must match encoding)
bits_per_channel: Bits per channel (LSB mode only) bits_per_channel: Bits per channel (LSB mode only)
embed_mode: 'auto' (try both), 'lsb', or 'dct' embed_mode: 'auto' (try both), 'lsb', or 'dct'
Returns: Returns:
Extracted data bytes, or None if extraction fails Extracted data bytes, or None if extraction fails
""" """
debug.print(f"extract_from_image: mode={embed_mode}") debug.print(f"extract_from_image: mode={embed_mode}")
# AUTO MODE: Try LSB first, then DCT # AUTO MODE: Try LSB first, then DCT
if embed_mode == EMBED_MODE_AUTO: if embed_mode == EMBED_MODE_AUTO:
result = _extract_lsb(image_data, pixel_key, bits_per_channel) result = _extract_lsb(image_data, pixel_key, bits_per_channel)
if result is not None: if result is not None:
debug.print("Auto-detect: LSB extraction succeeded") debug.print("Auto-detect: LSB extraction succeeded")
return result return result
if has_dct_support(): if has_dct_support():
debug.print("Auto-detect: LSB failed, trying DCT") debug.print("Auto-detect: LSB failed, trying DCT")
result = _extract_dct(image_data, pixel_key) result = _extract_dct(image_data, pixel_key)
if result is not None: if result is not None:
debug.print("Auto-detect: DCT extraction succeeded") debug.print("Auto-detect: DCT extraction succeeded")
return result return result
debug.print("Auto-detect: All modes failed") debug.print("Auto-detect: All modes failed")
return None return None
# EXPLICIT DCT MODE # EXPLICIT DCT MODE
elif embed_mode == EMBED_MODE_DCT: elif embed_mode == EMBED_MODE_DCT:
if not has_dct_support(): if not has_dct_support():
raise ImportError("scipy required for DCT mode") raise ImportError("scipy required for DCT mode")
return _extract_dct(image_data, pixel_key) return _extract_dct(image_data, pixel_key)
# EXPLICIT LSB MODE # EXPLICIT LSB MODE
else: else:
return _extract_lsb(image_data, pixel_key, bits_per_channel) return _extract_lsb(image_data, pixel_key, bits_per_channel)
def _extract_dct(image_data: bytes, pixel_key: bytes) -> Optional[bytes]: def _extract_dct(image_data: bytes, pixel_key: bytes) -> bytes | None:
"""Extract using DCT mode.""" """Extract using DCT mode."""
try: try:
dct_mod = _get_dct_module() dct_mod = _get_dct_module()
@@ -779,7 +781,7 @@ def _extract_lsb(
image_data: bytes, image_data: bytes,
pixel_key: bytes, pixel_key: bytes,
bits_per_channel: int = 1 bits_per_channel: int = 1
) -> Optional[bytes]: ) -> bytes | None:
""" """
Extract using LSB mode (internal implementation). Extract using LSB mode (internal implementation).
""" """
@@ -787,82 +789,82 @@ def _extract_lsb(
debug.data(pixel_key, "Pixel key for extraction") debug.data(pixel_key, "Pixel key for extraction")
debug.validate(bits_per_channel in (1, 2), debug.validate(bits_per_channel in (1, 2),
f"bits_per_channel must be 1 or 2, got {bits_per_channel}") f"bits_per_channel must be 1 or 2, got {bits_per_channel}")
img_file = None img_file = None
img = None img = None
try: try:
img_file = Image.open(io.BytesIO(image_data)) img_file = Image.open(io.BytesIO(image_data))
debug.print(f"Image: {img_file.size[0]}x{img_file.size[1]}, format: {img_file.format}") debug.print(f"Image: {img_file.size[0]}x{img_file.size[1]}, format: {img_file.format}")
img = img_file.convert('RGB') if img_file.mode != 'RGB' else img_file.copy() img = img_file.convert('RGB') if img_file.mode != 'RGB' else img_file.copy()
if img_file.mode != 'RGB': if img_file.mode != 'RGB':
debug.print(f"Converting image from {img_file.mode} to RGB") debug.print(f"Converting image from {img_file.mode} to RGB")
pixels = list(img.getdata()) pixels = list(img.getdata())
num_pixels = len(pixels) num_pixels = len(pixels)
bits_per_pixel = 3 * bits_per_channel bits_per_pixel = 3 * bits_per_channel
debug.print(f"Image has {num_pixels} pixels, {bits_per_pixel} bits/pixel") debug.print(f"Image has {num_pixels} pixels, {bits_per_pixel} bits/pixel")
initial_pixels = (32 + bits_per_pixel - 1) // bits_per_pixel + 10 initial_pixels = (32 + bits_per_pixel - 1) // bits_per_pixel + 10
debug.print(f"Extracting initial {initial_pixels} pixels to find length") debug.print(f"Extracting initial {initial_pixels} pixels to find length")
initial_indices = generate_pixel_indices(pixel_key, num_pixels, initial_pixels) initial_indices = generate_pixel_indices(pixel_key, num_pixels, initial_pixels)
binary_data = '' binary_data = ''
for pixel_idx in initial_indices: for pixel_idx in initial_indices:
r, g, b = pixels[pixel_idx] r, g, b = pixels[pixel_idx]
for channel in [r, g, b]: for channel in [r, g, b]:
for bit_pos in range(bits_per_channel - 1, -1, -1): for bit_pos in range(bits_per_channel - 1, -1, -1):
binary_data += str((channel >> bit_pos) & 1) binary_data += str((channel >> bit_pos) & 1)
try: try:
length_bits = binary_data[:32] length_bits = binary_data[:32]
if len(length_bits) < 32: if len(length_bits) < 32:
debug.print(f"Not enough bits for length: {len(length_bits)}/32") debug.print(f"Not enough bits for length: {len(length_bits)}/32")
return None return None
data_length = struct.unpack('>I', int(length_bits, 2).to_bytes(4, 'big'))[0] data_length = struct.unpack('>I', int(length_bits, 2).to_bytes(4, 'big'))[0]
debug.print(f"Extracted length: {data_length} bytes") debug.print(f"Extracted length: {data_length} bytes")
except Exception as e: except Exception as e:
debug.print(f"Failed to parse length: {e}") debug.print(f"Failed to parse length: {e}")
return None return None
max_possible = (num_pixels * bits_per_pixel) // 8 - 4 max_possible = (num_pixels * bits_per_pixel) // 8 - 4
if data_length > max_possible or data_length < 10: if data_length > max_possible or data_length < 10:
debug.print(f"Invalid data length: {data_length} (max possible: {max_possible})") debug.print(f"Invalid data length: {data_length} (max possible: {max_possible})")
return None return None
total_bits = (4 + data_length) * 8 total_bits = (4 + data_length) * 8
pixels_needed = (total_bits + bits_per_pixel - 1) // bits_per_pixel pixels_needed = (total_bits + bits_per_pixel - 1) // bits_per_pixel
debug.print(f"Need {pixels_needed} pixels to extract {data_length} bytes") debug.print(f"Need {pixels_needed} pixels to extract {data_length} bytes")
selected_indices = generate_pixel_indices(pixel_key, num_pixels, pixels_needed) selected_indices = generate_pixel_indices(pixel_key, num_pixels, pixels_needed)
binary_data = '' binary_data = ''
for pixel_idx in selected_indices: for pixel_idx in selected_indices:
r, g, b = pixels[pixel_idx] r, g, b = pixels[pixel_idx]
for channel in [r, g, b]: for channel in [r, g, b]:
for bit_pos in range(bits_per_channel - 1, -1, -1): for bit_pos in range(bits_per_channel - 1, -1, -1):
binary_data += str((channel >> bit_pos) & 1) binary_data += str((channel >> bit_pos) & 1)
data_bits = binary_data[32:32 + (data_length * 8)] data_bits = binary_data[32:32 + (data_length * 8)]
if len(data_bits) < data_length * 8: if len(data_bits) < data_length * 8:
debug.print(f"Insufficient bits: {len(data_bits)} < {data_length * 8}") debug.print(f"Insufficient bits: {len(data_bits)} < {data_length * 8}")
return None return None
data_bytes = bytearray() data_bytes = bytearray()
for i in range(0, len(data_bits), 8): for i in range(0, len(data_bits), 8):
byte_bits = data_bits[i:i + 8] byte_bits = data_bits[i:i + 8]
if len(byte_bits) == 8: if len(byte_bits) == 8:
data_bytes.append(int(byte_bits, 2)) data_bytes.append(int(byte_bits, 2))
debug.print(f"LSB successfully extracted {len(data_bytes)} bytes") debug.print(f"LSB successfully extracted {len(data_bytes)} bytes")
return bytes(data_bytes) return bytes(data_bytes)
except Exception as e: except Exception as e:
debug.exception(e, "extract_lsb") debug.exception(e, "extract_lsb")
return None return None
@@ -878,7 +880,7 @@ def _extract_lsb(
# UTILITY FUNCTIONS # UTILITY FUNCTIONS
# ============================================================================= # =============================================================================
def get_image_dimensions(image_data: bytes) -> Tuple[int, int]: def get_image_dimensions(image_data: bytes) -> tuple[int, int]:
"""Get image dimensions without loading full image.""" """Get image dimensions without loading full image."""
debug.validate(len(image_data) > 0, "Image data cannot be empty") debug.validate(len(image_data) > 0, "Image data cannot be empty")
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
@@ -890,7 +892,7 @@ def get_image_dimensions(image_data: bytes) -> Tuple[int, int]:
img.close() img.close()
def get_image_format(image_data: bytes) -> Optional[str]: def get_image_format(image_data: bytes) -> str | None:
"""Get image format (PIL format string like 'PNG', 'JPEG').""" """Get image format (PIL format string like 'PNG', 'JPEG')."""
try: try:
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))

View File

@@ -9,9 +9,8 @@ import os
import random import random
import secrets import secrets
import shutil import shutil
from datetime import date, datetime from datetime import date
from pathlib import Path from pathlib import Path
from typing import Optional, Union
from PIL import Image from PIL import Image
@@ -22,98 +21,98 @@ from .debug import debug
def strip_image_metadata(image_data: bytes, output_format: str = 'PNG') -> bytes: def strip_image_metadata(image_data: bytes, output_format: str = 'PNG') -> bytes:
""" """
Remove all metadata (EXIF, ICC profiles, etc.) from an image. Remove all metadata (EXIF, ICC profiles, etc.) from an image.
Creates a fresh image with only pixel data - no EXIF, GPS coordinates, Creates a fresh image with only pixel data - no EXIF, GPS coordinates,
camera info, timestamps, or other potentially sensitive metadata. camera info, timestamps, or other potentially sensitive metadata.
Args: Args:
image_data: Raw image bytes image_data: Raw image bytes
output_format: Output format ('PNG', 'BMP', 'TIFF') output_format: Output format ('PNG', 'BMP', 'TIFF')
Returns: Returns:
Clean image bytes with no metadata Clean image bytes with no metadata
Example: Example:
>>> clean = strip_image_metadata(photo_bytes) >>> clean = strip_image_metadata(photo_bytes)
>>> # EXIF data is now removed >>> # EXIF data is now removed
""" """
debug.print(f"Stripping metadata, output format: {output_format}") debug.print(f"Stripping metadata, output format: {output_format}")
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
# Convert to RGB if needed (handles RGBA, P, L, etc.) # Convert to RGB if needed (handles RGBA, P, L, etc.)
if img.mode not in ('RGB', 'RGBA'): if img.mode not in ('RGB', 'RGBA'):
img = img.convert('RGB') img = img.convert('RGB')
# Create fresh image - this discards all metadata # Create fresh image - this discards all metadata
clean = Image.new(img.mode, img.size) clean = Image.new(img.mode, img.size)
clean.putdata(list(img.getdata())) clean.putdata(list(img.getdata()))
output = io.BytesIO() output = io.BytesIO()
clean.save(output, output_format.upper()) clean.save(output, output_format.upper())
output.seek(0) output.seek(0)
debug.print(f"Metadata stripped: {len(image_data)} -> {len(output.getvalue())} bytes") debug.print(f"Metadata stripped: {len(image_data)} -> {len(output.getvalue())} bytes")
return output.getvalue() return output.getvalue()
def generate_filename( def generate_filename(
date_str: Optional[str] = None, date_str: str | None = None,
prefix: str = "", prefix: str = "",
extension: str = "png" extension: str = "png"
) -> str: ) -> str:
""" """
Generate a filename for stego images. Generate a filename for stego images.
Format: {prefix}{random}_{YYYYMMDD}.{extension} Format: {prefix}{random}_{YYYYMMDD}.{extension}
Args: Args:
date_str: Date string (YYYY-MM-DD), defaults to today date_str: Date string (YYYY-MM-DD), defaults to today
prefix: Optional prefix prefix: Optional prefix
extension: File extension without dot (default: 'png') extension: File extension without dot (default: 'png')
Returns: Returns:
Filename string Filename string
Example: Example:
>>> generate_filename("2023-12-25", "secret_", "png") >>> generate_filename("2023-12-25", "secret_", "png")
"secret_a1b2c3d4_20231225.png" "secret_a1b2c3d4_20231225.png"
""" """
debug.validate(bool(extension) and '.' not in extension, debug.validate(bool(extension) and '.' not in extension,
f"Extension must not contain dot, got '{extension}'") f"Extension must not contain dot, got '{extension}'")
if date_str is None: if date_str is None:
date_str = date.today().isoformat() date_str = date.today().isoformat()
date_compact = date_str.replace('-', '') date_compact = date_str.replace('-', '')
random_hex = secrets.token_hex(4) random_hex = secrets.token_hex(4)
# Ensure extension doesn't have a leading dot # Ensure extension doesn't have a leading dot
extension = extension.lstrip('.') extension = extension.lstrip('.')
filename = f"{prefix}{random_hex}_{date_compact}.{extension}" filename = f"{prefix}{random_hex}_{date_compact}.{extension}"
debug.print(f"Generated filename: {filename}") debug.print(f"Generated filename: {filename}")
return filename return filename
def parse_date_from_filename(filename: str) -> Optional[str]: def parse_date_from_filename(filename: str) -> str | None:
""" """
Extract date from a stego filename. Extract date from a stego filename.
Looks for patterns like _20251227 or _2025-12-27 Looks for patterns like _20251227 or _2025-12-27
Args: Args:
filename: Filename to parse filename: Filename to parse
Returns: Returns:
Date string (YYYY-MM-DD) or None Date string (YYYY-MM-DD) or None
Example: Example:
>>> parse_date_from_filename("secret_a1b2c3d4_20231225.png") >>> parse_date_from_filename("secret_a1b2c3d4_20231225.png")
"2023-12-25" "2023-12-25"
""" """
import re import re
# Try YYYYMMDD format # Try YYYYMMDD format
match = re.search(r'_(\d{4})(\d{2})(\d{2})(?:\.|$)', filename) match = re.search(r'_(\d{4})(\d{2})(\d{2})(?:\.|$)', filename)
if match: if match:
@@ -121,7 +120,7 @@ def parse_date_from_filename(filename: str) -> Optional[str]:
date_str = f"{year}-{month}-{day}" date_str = f"{year}-{month}-{day}"
debug.print(f"Parsed date (compact): {date_str}") debug.print(f"Parsed date (compact): {date_str}")
return date_str return date_str
# Try YYYY-MM-DD format # Try YYYY-MM-DD format
match = re.search(r'_(\d{4})-(\d{2})-(\d{2})(?:\.|$)', filename) match = re.search(r'_(\d{4})-(\d{2})-(\d{2})(?:\.|$)', filename)
if match: if match:
@@ -129,7 +128,7 @@ def parse_date_from_filename(filename: str) -> Optional[str]:
date_str = f"{year}-{month}-{day}" date_str = f"{year}-{month}-{day}"
debug.print(f"Parsed date (dashed): {date_str}") debug.print(f"Parsed date (dashed): {date_str}")
return date_str return date_str
debug.print(f"No date found in filename: {filename}") debug.print(f"No date found in filename: {filename}")
return None return None
@@ -137,20 +136,20 @@ def parse_date_from_filename(filename: str) -> Optional[str]:
def get_day_from_date(date_str: str) -> str: def get_day_from_date(date_str: str) -> str:
""" """
Get day of week name from date string. Get day of week name from date string.
Args: Args:
date_str: Date string (YYYY-MM-DD) date_str: Date string (YYYY-MM-DD)
Returns: Returns:
Day name (e.g., "Monday") Day name (e.g., "Monday")
Example: Example:
>>> get_day_from_date("2023-12-25") >>> get_day_from_date("2023-12-25")
"Monday" "Monday"
""" """
debug.validate(len(date_str) == 10 and date_str[4] == '-' and date_str[7] == '-', debug.validate(len(date_str) == 10 and date_str[4] == '-' and date_str[7] == '-',
f"Invalid date format: {date_str}, expected YYYY-MM-DD") f"Invalid date format: {date_str}, expected YYYY-MM-DD")
try: try:
year, month, day = map(int, date_str.split('-')) year, month, day = map(int, date_str.split('-'))
d = date(year, month, day) d = date(year, month, day)
@@ -165,10 +164,10 @@ def get_day_from_date(date_str: str) -> str:
def get_today_date() -> str: def get_today_date() -> str:
""" """
Get today's date as YYYY-MM-DD. Get today's date as YYYY-MM-DD.
Returns: Returns:
Today's date string Today's date string
Example: Example:
>>> get_today_date() >>> get_today_date()
"2023-12-25" "2023-12-25"
@@ -181,10 +180,10 @@ def get_today_date() -> str:
def get_today_day() -> str: def get_today_day() -> str:
""" """
Get today's day name. Get today's day name.
Returns: Returns:
Today's day name Today's day name
Example: Example:
>>> get_today_day() >>> get_today_day()
"Monday" "Monday"
@@ -197,43 +196,43 @@ def get_today_day() -> str:
class SecureDeleter: class SecureDeleter:
""" """
Securely delete files by overwriting with random data. Securely delete files by overwriting with random data.
Implements multi-pass overwriting before deletion. Implements multi-pass overwriting before deletion.
Example: Example:
>>> deleter = SecureDeleter("secret.txt", passes=3) >>> deleter = SecureDeleter("secret.txt", passes=3)
>>> deleter.execute() >>> deleter.execute()
""" """
def __init__(self, path: Union[str, Path], passes: int = 7): def __init__(self, path: str | Path, passes: int = 7):
""" """
Initialize secure deleter. Initialize secure deleter.
Args: Args:
path: Path to file or directory path: Path to file or directory
passes: Number of overwrite passes passes: Number of overwrite passes
""" """
debug.validate(passes > 0, f"Passes must be positive, got {passes}") debug.validate(passes > 0, f"Passes must be positive, got {passes}")
self.path = Path(path) self.path = Path(path)
self.passes = passes self.passes = passes
debug.print(f"SecureDeleter initialized for {self.path} with {passes} passes") debug.print(f"SecureDeleter initialized for {self.path} with {passes} passes")
def _overwrite_file(self, file_path: Path) -> None: def _overwrite_file(self, file_path: Path) -> None:
"""Overwrite file with random data multiple times.""" """Overwrite file with random data multiple times."""
if not file_path.exists() or not file_path.is_file(): if not file_path.exists() or not file_path.is_file():
debug.print(f"File does not exist or is not a file: {file_path}") debug.print(f"File does not exist or is not a file: {file_path}")
return return
length = file_path.stat().st_size length = file_path.stat().st_size
debug.print(f"Overwriting file {file_path} ({length} bytes)") debug.print(f"Overwriting file {file_path} ({length} bytes)")
if length == 0: if length == 0:
debug.print("File is empty, nothing to overwrite") debug.print("File is empty, nothing to overwrite")
return return
patterns = [b'\x00', b'\xFF', bytes([random.randint(0, 255)])] patterns = [b'\x00', b'\xFF', bytes([random.randint(0, 255)])]
for pass_num in range(self.passes): for pass_num in range(self.passes):
debug.print(f"Overwrite pass {pass_num + 1}/{self.passes}") debug.print(f"Overwrite pass {pass_num + 1}/{self.passes}")
with open(file_path, 'r+b') as f: with open(file_path, 'r+b') as f:
@@ -245,13 +244,13 @@ class SecureDeleter:
chunk = min(chunk_size, length - offset) chunk = min(chunk_size, length - offset)
f.write(pattern * (chunk // len(pattern))) f.write(pattern * (chunk // len(pattern)))
f.write(pattern[:chunk % len(pattern)]) f.write(pattern[:chunk % len(pattern)])
# Final pass with random data # Final pass with random data
f.seek(0) f.seek(0)
f.write(os.urandom(length)) f.write(os.urandom(length))
debug.print(f"Completed {self.passes} overwrite passes") debug.print(f"Completed {self.passes} overwrite passes")
def delete_file(self) -> None: def delete_file(self) -> None:
"""Securely delete a single file.""" """Securely delete a single file."""
if self.path.is_file(): if self.path.is_file():
@@ -261,28 +260,28 @@ class SecureDeleter:
debug.print(f"File deleted: {self.path}") debug.print(f"File deleted: {self.path}")
else: else:
debug.print(f"Not a file: {self.path}") debug.print(f"Not a file: {self.path}")
def delete_directory(self) -> None: def delete_directory(self) -> None:
"""Securely delete a directory and all contents.""" """Securely delete a directory and all contents."""
if not self.path.is_dir(): if not self.path.is_dir():
debug.print(f"Not a directory: {self.path}") debug.print(f"Not a directory: {self.path}")
return return
debug.print(f"Securely deleting directory: {self.path}") debug.print(f"Securely deleting directory: {self.path}")
# First, securely overwrite all files # First, securely overwrite all files
file_count = 0 file_count = 0
for file_path in self.path.rglob('*'): for file_path in self.path.rglob('*'):
if file_path.is_file(): if file_path.is_file():
self._overwrite_file(file_path) self._overwrite_file(file_path)
file_count += 1 file_count += 1
debug.print(f"Overwrote {file_count} files") debug.print(f"Overwrote {file_count} files")
# Then remove the directory tree # Then remove the directory tree
shutil.rmtree(self.path) shutil.rmtree(self.path)
debug.print(f"Directory deleted: {self.path}") debug.print(f"Directory deleted: {self.path}")
def execute(self) -> None: def execute(self) -> None:
"""Securely delete the path (file or directory).""" """Securely delete the path (file or directory)."""
debug.print(f"Executing secure deletion: {self.path}") debug.print(f"Executing secure deletion: {self.path}")
@@ -294,14 +293,14 @@ class SecureDeleter:
debug.print(f"Path does not exist: {self.path}") debug.print(f"Path does not exist: {self.path}")
def secure_delete(path: Union[str, Path], passes: int = 7) -> None: def secure_delete(path: str | Path, passes: int = 7) -> None:
""" """
Convenience function for secure deletion. Convenience function for secure deletion.
Args: Args:
path: Path to file or directory path: Path to file or directory
passes: Number of overwrite passes passes: Number of overwrite passes
Example: Example:
>>> secure_delete("secret.txt", passes=3) >>> secure_delete("secret.txt", passes=3)
""" """
@@ -312,19 +311,19 @@ def secure_delete(path: Union[str, Path], passes: int = 7) -> None:
def format_file_size(size_bytes: int) -> str: def format_file_size(size_bytes: int) -> str:
""" """
Format file size for display. Format file size for display.
Args: Args:
size_bytes: Size in bytes size_bytes: Size in bytes
Returns: Returns:
Human-readable string (e.g., "1.5 MB") Human-readable string (e.g., "1.5 MB")
Example: Example:
>>> format_file_size(1500000) >>> format_file_size(1500000)
"1.5 MB" "1.5 MB"
""" """
debug.validate(size_bytes >= 0, f"File size cannot be negative: {size_bytes}") debug.validate(size_bytes >= 0, f"File size cannot be negative: {size_bytes}")
size: float = float(size_bytes) size: float = float(size_bytes)
for unit in ['B', 'KB', 'MB', 'GB']: for unit in ['B', 'KB', 'MB', 'GB']:
if size < 1024: if size < 1024:
@@ -338,13 +337,13 @@ def format_file_size(size_bytes: int) -> str:
def format_number(n: int) -> str: def format_number(n: int) -> str:
""" """
Format number with commas. Format number with commas.
Args: Args:
n: Integer to format n: Integer to format
Returns: Returns:
Formatted string Formatted string
Example: Example:
>>> format_number(1234567) >>> format_number(1234567)
"1,234,567" "1,234,567"
@@ -356,15 +355,15 @@ def format_number(n: int) -> str:
def clamp(value: int, min_val: int, max_val: int) -> int: def clamp(value: int, min_val: int, max_val: int) -> int:
""" """
Clamp value to range. Clamp value to range.
Args: Args:
value: Value to clamp value: Value to clamp
min_val: Minimum allowed value min_val: Minimum allowed value
max_val: Maximum allowed value max_val: Maximum allowed value
Returns: Returns:
Clamped value Clamped value
Example: Example:
>>> clamp(15, 0, 10) >>> clamp(15, 0, 10)
10 10

View File

@@ -10,40 +10,50 @@ Changes in v3.2.0:
""" """
import io import io
from typing import Optional, Union
from PIL import Image from PIL import Image
from .constants import ( from .constants import (
MIN_PIN_LENGTH, MAX_PIN_LENGTH, ALLOWED_IMAGE_EXTENSIONS,
MAX_MESSAGE_SIZE, MAX_FILE_PAYLOAD_SIZE, MAX_IMAGE_PIXELS, MAX_FILE_SIZE, ALLOWED_KEY_EXTENSIONS,
MIN_RSA_BITS, MIN_KEY_PASSWORD_LENGTH, EMBED_MODE_AUTO,
ALLOWED_IMAGE_EXTENSIONS, ALLOWED_KEY_EXTENSIONS, EMBED_MODE_DCT,
MIN_PASSPHRASE_WORDS, RECOMMENDED_PASSPHRASE_WORDS, EMBED_MODE_LSB,
EMBED_MODE_LSB, EMBED_MODE_DCT, EMBED_MODE_AUTO, MAX_FILE_PAYLOAD_SIZE,
MAX_FILE_SIZE,
MAX_IMAGE_PIXELS,
MAX_MESSAGE_SIZE,
MAX_PIN_LENGTH,
MIN_KEY_PASSWORD_LENGTH,
MIN_PASSPHRASE_WORDS,
MIN_PIN_LENGTH,
MIN_RSA_BITS,
RECOMMENDED_PASSPHRASE_WORDS,
) )
from .models import ValidationResult, FilePayload
from .exceptions import ( from .exceptions import (
ValidationError, PinValidationError, MessageValidationError, ImageValidationError,
ImageValidationError, KeyValidationError, SecurityFactorError, KeyValidationError,
FileTooLargeError, UnsupportedFileTypeError, MessageValidationError,
PinValidationError,
SecurityFactorError,
) )
from .keygen import load_rsa_key from .keygen import load_rsa_key
from .models import FilePayload, ValidationResult
def validate_pin(pin: str, required: bool = False) -> ValidationResult: def validate_pin(pin: str, required: bool = False) -> ValidationResult:
""" """
Validate PIN format. Validate PIN format.
Rules: Rules:
- 6-9 digits only - 6-9 digits only
- Cannot start with zero - Cannot start with zero
- Empty is OK if not required - Empty is OK if not required
Args: Args:
pin: PIN string to validate pin: PIN string to validate
required: Whether PIN is required required: Whether PIN is required
Returns: Returns:
ValidationResult ValidationResult
""" """
@@ -51,83 +61,83 @@ def validate_pin(pin: str, required: bool = False) -> ValidationResult:
if required: if required:
return ValidationResult.error("PIN is required") return ValidationResult.error("PIN is required")
return ValidationResult.ok() return ValidationResult.ok()
if not pin.isdigit(): if not pin.isdigit():
return ValidationResult.error("PIN must contain only digits") return ValidationResult.error("PIN must contain only digits")
if len(pin) < MIN_PIN_LENGTH or len(pin) > MAX_PIN_LENGTH: if len(pin) < MIN_PIN_LENGTH or len(pin) > MAX_PIN_LENGTH:
return ValidationResult.error( return ValidationResult.error(
f"PIN must be {MIN_PIN_LENGTH}-{MAX_PIN_LENGTH} digits" f"PIN must be {MIN_PIN_LENGTH}-{MAX_PIN_LENGTH} digits"
) )
if pin[0] == '0': if pin[0] == '0':
return ValidationResult.error("PIN cannot start with zero") return ValidationResult.error("PIN cannot start with zero")
return ValidationResult.ok(length=len(pin)) return ValidationResult.ok(length=len(pin))
def validate_message(message: str) -> ValidationResult: def validate_message(message: str) -> ValidationResult:
""" """
Validate text message content and size. Validate text message content and size.
Args: Args:
message: Message text message: Message text
Returns: Returns:
ValidationResult ValidationResult
""" """
if not message: if not message:
return ValidationResult.error("Message is required") return ValidationResult.error("Message is required")
if len(message) > MAX_MESSAGE_SIZE: if len(message) > MAX_MESSAGE_SIZE:
return ValidationResult.error( return ValidationResult.error(
f"Message too long ({len(message):,} chars). Maximum: {MAX_MESSAGE_SIZE:,} characters" f"Message too long ({len(message):,} chars). Maximum: {MAX_MESSAGE_SIZE:,} characters"
) )
return ValidationResult.ok(length=len(message)) return ValidationResult.ok(length=len(message))
def validate_payload(payload: Union[str, bytes, FilePayload]) -> ValidationResult: def validate_payload(payload: str | bytes | FilePayload) -> ValidationResult:
""" """
Validate a payload (text message, bytes, or file). Validate a payload (text message, bytes, or file).
Args: Args:
payload: Text string, raw bytes, or FilePayload payload: Text string, raw bytes, or FilePayload
Returns: Returns:
ValidationResult ValidationResult
""" """
if isinstance(payload, str): if isinstance(payload, str):
return validate_message(payload) return validate_message(payload)
elif isinstance(payload, FilePayload): elif isinstance(payload, FilePayload):
if not payload.data: if not payload.data:
return ValidationResult.error("File is empty") return ValidationResult.error("File is empty")
if len(payload.data) > MAX_FILE_PAYLOAD_SIZE: if len(payload.data) > MAX_FILE_PAYLOAD_SIZE:
return ValidationResult.error( return ValidationResult.error(
f"File too large ({len(payload.data):,} bytes). " f"File too large ({len(payload.data):,} bytes). "
f"Maximum: {MAX_FILE_PAYLOAD_SIZE:,} bytes ({MAX_FILE_PAYLOAD_SIZE // 1024} KB)" f"Maximum: {MAX_FILE_PAYLOAD_SIZE:,} bytes ({MAX_FILE_PAYLOAD_SIZE // 1024} KB)"
) )
return ValidationResult.ok( return ValidationResult.ok(
size=len(payload.data), size=len(payload.data),
filename=payload.filename, filename=payload.filename,
mime_type=payload.mime_type mime_type=payload.mime_type
) )
elif isinstance(payload, bytes): elif isinstance(payload, bytes):
if not payload: if not payload:
return ValidationResult.error("Payload is empty") return ValidationResult.error("Payload is empty")
if len(payload) > MAX_FILE_PAYLOAD_SIZE: if len(payload) > MAX_FILE_PAYLOAD_SIZE:
return ValidationResult.error( return ValidationResult.error(
f"Payload too large ({len(payload):,} bytes). " f"Payload too large ({len(payload):,} bytes). "
f"Maximum: {MAX_FILE_PAYLOAD_SIZE:,} bytes ({MAX_FILE_PAYLOAD_SIZE // 1024} KB)" f"Maximum: {MAX_FILE_PAYLOAD_SIZE:,} bytes ({MAX_FILE_PAYLOAD_SIZE // 1024} KB)"
) )
return ValidationResult.ok(size=len(payload)) return ValidationResult.ok(size=len(payload))
else: else:
return ValidationResult.error(f"Invalid payload type: {type(payload)}") return ValidationResult.error(f"Invalid payload type: {type(payload)}")
@@ -139,18 +149,18 @@ def validate_file_payload(
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate a file for embedding. Validate a file for embedding.
Args: Args:
file_data: Raw file bytes file_data: Raw file bytes
filename: Original filename (for display in errors) filename: Original filename (for display in errors)
max_size: Maximum allowed size in bytes max_size: Maximum allowed size in bytes
Returns: Returns:
ValidationResult ValidationResult
""" """
if not file_data: if not file_data:
return ValidationResult.error("File is empty") return ValidationResult.error("File is empty")
if len(file_data) > max_size: if len(file_data) > max_size:
size_kb = len(file_data) / 1024 size_kb = len(file_data) / 1024
max_kb = max_size / 1024 max_kb = max_size / 1024
@@ -158,7 +168,7 @@ def validate_file_payload(
f"File '{filename or 'unnamed'}' too large ({size_kb:.1f} KB). " f"File '{filename or 'unnamed'}' too large ({size_kb:.1f} KB). "
f"Maximum: {max_kb:.0f} KB" f"Maximum: {max_kb:.0f} KB"
) )
return ValidationResult.ok(size=len(file_data), filename=filename) return ValidationResult.ok(size=len(file_data), filename=filename)
@@ -169,35 +179,35 @@ def validate_image(
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate image data and dimensions. Validate image data and dimensions.
Args: Args:
image_data: Raw image bytes image_data: Raw image bytes
name: Name for error messages name: Name for error messages
check_size: Whether to check pixel dimensions check_size: Whether to check pixel dimensions
Returns: Returns:
ValidationResult with width, height, pixels ValidationResult with width, height, pixels
""" """
if not image_data: if not image_data:
return ValidationResult.error(f"{name} is required") return ValidationResult.error(f"{name} is required")
if len(image_data) > MAX_FILE_SIZE: if len(image_data) > MAX_FILE_SIZE:
return ValidationResult.error( return ValidationResult.error(
f"{name} too large ({len(image_data):,} bytes). Maximum: {MAX_FILE_SIZE:,} bytes" f"{name} too large ({len(image_data):,} bytes). Maximum: {MAX_FILE_SIZE:,} bytes"
) )
try: try:
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
width, height = img.size width, height = img.size
num_pixels = width * height num_pixels = width * height
if check_size and num_pixels > MAX_IMAGE_PIXELS: if check_size and num_pixels > MAX_IMAGE_PIXELS:
max_dim = int(MAX_IMAGE_PIXELS ** 0.5) max_dim = int(MAX_IMAGE_PIXELS ** 0.5)
return ValidationResult.error( return ValidationResult.error(
f"{name} too large ({width}×{height} = {num_pixels:,} pixels). " f"{name} too large ({width}×{height} = {num_pixels:,} pixels). "
f"Maximum: ~{MAX_IMAGE_PIXELS:,} pixels ({max_dim}×{max_dim})" f"Maximum: ~{MAX_IMAGE_PIXELS:,} pixels ({max_dim}×{max_dim})"
) )
return ValidationResult.ok( return ValidationResult.ok(
width=width, width=width,
height=height, height=height,
@@ -205,24 +215,24 @@ def validate_image(
mode=img.mode, mode=img.mode,
format=img.format format=img.format
) )
except Exception as e: except Exception as e:
return ValidationResult.error(f"Could not read {name}: {e}") return ValidationResult.error(f"Could not read {name}: {e}")
def validate_rsa_key( def validate_rsa_key(
key_data: bytes, key_data: bytes,
password: Optional[str] = None, password: str | None = None,
required: bool = False required: bool = False
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate RSA private key. Validate RSA private key.
Args: Args:
key_data: PEM-encoded key bytes key_data: PEM-encoded key bytes
password: Password if key is encrypted password: Password if key is encrypted
required: Whether key is required required: Whether key is required
Returns: Returns:
ValidationResult with key_size ValidationResult with key_size
""" """
@@ -230,44 +240,44 @@ def validate_rsa_key(
if required: if required:
return ValidationResult.error("RSA key is required") return ValidationResult.error("RSA key is required")
return ValidationResult.ok() return ValidationResult.ok()
try: try:
private_key = load_rsa_key(key_data, password) private_key = load_rsa_key(key_data, password)
key_size = private_key.key_size key_size = private_key.key_size
if key_size < MIN_RSA_BITS: if key_size < MIN_RSA_BITS:
return ValidationResult.error( return ValidationResult.error(
f"RSA key must be at least {MIN_RSA_BITS} bits (got {key_size})" f"RSA key must be at least {MIN_RSA_BITS} bits (got {key_size})"
) )
return ValidationResult.ok(key_size=key_size) return ValidationResult.ok(key_size=key_size)
except Exception as e: except Exception as e:
return ValidationResult.error(str(e)) return ValidationResult.error(str(e))
def validate_security_factors( def validate_security_factors(
pin: str, pin: str,
rsa_key_data: Optional[bytes] rsa_key_data: bytes | None
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate that at least one security factor is provided. Validate that at least one security factor is provided.
Args: Args:
pin: PIN string (may be empty) pin: PIN string (may be empty)
rsa_key_data: RSA key bytes (may be None/empty) rsa_key_data: RSA key bytes (may be None/empty)
Returns: Returns:
ValidationResult ValidationResult
""" """
has_pin = bool(pin and pin.strip()) has_pin = bool(pin and pin.strip())
has_key = bool(rsa_key_data and len(rsa_key_data) > 0) has_key = bool(rsa_key_data and len(rsa_key_data) > 0)
if not has_pin and not has_key: if not has_pin and not has_key:
return ValidationResult.error( return ValidationResult.error(
"You must provide at least a PIN or RSA Key" "You must provide at least a PIN or RSA Key"
) )
return ValidationResult.ok(has_pin=has_pin, has_key=has_key) return ValidationResult.ok(has_pin=has_pin, has_key=has_key)
@@ -278,26 +288,26 @@ def validate_file_extension(
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate file extension. Validate file extension.
Args: Args:
filename: Filename to check filename: Filename to check
allowed: Set of allowed extensions (lowercase, no dot) allowed: Set of allowed extensions (lowercase, no dot)
file_type: Name for error messages file_type: Name for error messages
Returns: Returns:
ValidationResult with extension ValidationResult with extension
""" """
if not filename or '.' not in filename: if not filename or '.' not in filename:
return ValidationResult.error(f"{file_type} must have a file extension") return ValidationResult.error(f"{file_type} must have a file extension")
ext = filename.rsplit('.', 1)[1].lower() ext = filename.rsplit('.', 1)[1].lower()
if ext not in allowed: if ext not in allowed:
return ValidationResult.error( return ValidationResult.error(
f"Unsupported {file_type.lower()} type: .{ext}. " f"Unsupported {file_type.lower()} type: .{ext}. "
f"Allowed: {', '.join(sorted('.' + e for e in allowed))}" f"Allowed: {', '.join(sorted('.' + e for e in allowed))}"
) )
return ValidationResult.ok(extension=ext) return ValidationResult.ok(extension=ext)
@@ -314,53 +324,53 @@ def validate_key_file(filename: str) -> ValidationResult:
def validate_key_password(password: str) -> ValidationResult: def validate_key_password(password: str) -> ValidationResult:
""" """
Validate password for key encryption. Validate password for key encryption.
Args: Args:
password: Password string password: Password string
Returns: Returns:
ValidationResult ValidationResult
""" """
if not password: if not password:
return ValidationResult.error("Password is required") return ValidationResult.error("Password is required")
if len(password) < MIN_KEY_PASSWORD_LENGTH: if len(password) < MIN_KEY_PASSWORD_LENGTH:
return ValidationResult.error( return ValidationResult.error(
f"Password must be at least {MIN_KEY_PASSWORD_LENGTH} characters" f"Password must be at least {MIN_KEY_PASSWORD_LENGTH} characters"
) )
return ValidationResult.ok(length=len(password)) return ValidationResult.ok(length=len(password))
def validate_passphrase(passphrase: str) -> ValidationResult: def validate_passphrase(passphrase: str) -> ValidationResult:
""" """
Validate passphrase. Validate passphrase.
v3.2.0: Recommend 4+ words for good entropy (since date is no longer used). v3.2.0: Recommend 4+ words for good entropy (since date is no longer used).
Args: Args:
passphrase: Passphrase string passphrase: Passphrase string
Returns: Returns:
ValidationResult with word_count and optional warning ValidationResult with word_count and optional warning
""" """
if not passphrase or not passphrase.strip(): if not passphrase or not passphrase.strip():
return ValidationResult.error("Passphrase is required") return ValidationResult.error("Passphrase is required")
words = passphrase.strip().split() words = passphrase.strip().split()
if len(words) < MIN_PASSPHRASE_WORDS: if len(words) < MIN_PASSPHRASE_WORDS:
return ValidationResult.error( return ValidationResult.error(
f"Passphrase should have at least {MIN_PASSPHRASE_WORDS} words" f"Passphrase should have at least {MIN_PASSPHRASE_WORDS} words"
) )
# Provide warning if below recommended length # Provide warning if below recommended length
if len(words) < RECOMMENDED_PASSPHRASE_WORDS: if len(words) < RECOMMENDED_PASSPHRASE_WORDS:
return ValidationResult.ok( return ValidationResult.ok(
word_count=len(words), word_count=len(words),
warning=f"Recommend {RECOMMENDED_PASSPHRASE_WORDS}+ words for better security" warning=f"Recommend {RECOMMENDED_PASSPHRASE_WORDS}+ words for better security"
) )
return ValidationResult.ok(word_count=len(words)) return ValidationResult.ok(word_count=len(words))
@@ -381,60 +391,60 @@ def validate_carrier(carrier_data: bytes) -> ValidationResult:
def validate_embed_mode(mode: str) -> ValidationResult: def validate_embed_mode(mode: str) -> ValidationResult:
""" """
Validate embedding mode. Validate embedding mode.
Args: Args:
mode: Embedding mode string mode: Embedding mode string
Returns: Returns:
ValidationResult ValidationResult
""" """
valid_modes = {EMBED_MODE_LSB, EMBED_MODE_DCT, EMBED_MODE_AUTO} valid_modes = {EMBED_MODE_LSB, EMBED_MODE_DCT, EMBED_MODE_AUTO}
if mode not in valid_modes: if mode not in valid_modes:
return ValidationResult.error( return ValidationResult.error(
f"Invalid embed_mode: '{mode}'. Valid options: {', '.join(sorted(valid_modes))}" f"Invalid embed_mode: '{mode}'. Valid options: {', '.join(sorted(valid_modes))}"
) )
return ValidationResult.ok(mode=mode) return ValidationResult.ok(mode=mode)
def validate_dct_output_format(format_str: str) -> ValidationResult: def validate_dct_output_format(format_str: str) -> ValidationResult:
""" """
Validate DCT output format. Validate DCT output format.
Args: Args:
format_str: Output format ('png' or 'jpeg') format_str: Output format ('png' or 'jpeg')
Returns: Returns:
ValidationResult ValidationResult
""" """
valid_formats = {'png', 'jpeg'} valid_formats = {'png', 'jpeg'}
if format_str.lower() not in valid_formats: if format_str.lower() not in valid_formats:
return ValidationResult.error( return ValidationResult.error(
f"Invalid DCT output format: '{format_str}'. Valid options: {', '.join(sorted(valid_formats))}" f"Invalid DCT output format: '{format_str}'. Valid options: {', '.join(sorted(valid_formats))}"
) )
return ValidationResult.ok(format=format_str.lower()) return ValidationResult.ok(format=format_str.lower())
def validate_dct_color_mode(mode: str) -> ValidationResult: def validate_dct_color_mode(mode: str) -> ValidationResult:
""" """
Validate DCT color mode. Validate DCT color mode.
Args: Args:
mode: Color mode ('grayscale' or 'color') mode: Color mode ('grayscale' or 'color')
Returns: Returns:
ValidationResult ValidationResult
""" """
valid_modes = {'grayscale', 'color'} valid_modes = {'grayscale', 'color'}
if mode.lower() not in valid_modes: if mode.lower() not in valid_modes:
return ValidationResult.error( return ValidationResult.error(
f"Invalid DCT color mode: '{mode}'. Valid options: {', '.join(sorted(valid_modes))}" f"Invalid DCT color mode: '{mode}'. Valid options: {', '.join(sorted(valid_modes))}"
) )
return ValidationResult.ok(mode=mode.lower()) return ValidationResult.ok(mode=mode.lower())
@@ -456,7 +466,7 @@ def require_valid_message(message: str) -> None:
raise MessageValidationError(result.error_message) raise MessageValidationError(result.error_message)
def require_valid_payload(payload: Union[str, bytes, FilePayload]) -> None: def require_valid_payload(payload: str | bytes | FilePayload) -> None:
"""Validate payload (text, bytes, or file), raising exception on failure.""" """Validate payload (text, bytes, or file), raising exception on failure."""
result = validate_payload(payload) result = validate_payload(payload)
if not result.is_valid: if not result.is_valid:
@@ -472,7 +482,7 @@ def require_valid_image(image_data: bytes, name: str = "Image") -> None:
def require_valid_rsa_key( def require_valid_rsa_key(
key_data: bytes, key_data: bytes,
password: Optional[str] = None, password: str | None = None,
required: bool = False required: bool = False
) -> None: ) -> None:
"""Validate RSA key, raising exception on failure.""" """Validate RSA key, raising exception on failure."""
@@ -481,7 +491,7 @@ def require_valid_rsa_key(
raise KeyValidationError(result.error_message) raise KeyValidationError(result.error_message)
def require_security_factors(pin: str, rsa_key_data: Optional[bytes]) -> None: def require_security_factors(pin: str, rsa_key_data: bytes | None) -> None:
"""Validate security factors, raising exception on failure.""" """Validate security factors, raising exception on failure."""
result = validate_security_factors(pin, rsa_key_data) result = validate_security_factors(pin, rsa_key_data)
if not result.is_valid: if not result.is_valid:

View File

@@ -131,36 +131,36 @@ if HAS_JPEGIO:
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("JPEGIO SPECIFIC TEST") print("JPEGIO SPECIFIC TEST")
print("=" * 60) print("=" * 60)
import tempfile import tempfile
import os import os
# Reload image data # Reload image data
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
carrier_data = f.read() carrier_data = f.read()
print("\n[J1] Checking if image is JPEG...") print("\n[J1] Checking if image is JPEG...")
img = Image.open(io.BytesIO(carrier_data)) img = Image.open(io.BytesIO(carrier_data))
is_jpeg = img.format == 'JPEG' is_jpeg = img.format == 'JPEG'
img.close() img.close()
print(f" Is JPEG: {is_jpeg}") print(f" Is JPEG: {is_jpeg}")
if is_jpeg: if is_jpeg:
print("\n[J2] Writing to temp file...") print("\n[J2] Writing to temp file...")
fd, temp_path = tempfile.mkstemp(suffix='.jpg') fd, temp_path = tempfile.mkstemp(suffix='.jpg')
os.write(fd, carrier_data) os.write(fd, carrier_data)
os.close(fd) os.close(fd)
print(f" Temp file: {temp_path}") print(f" Temp file: {temp_path}")
print("\n[J3] Reading with jpegio...") print("\n[J3] Reading with jpegio...")
try: try:
jpeg = jio.read(temp_path) jpeg = jio.read(temp_path)
print(f" jpegio.read() OK") print(f" jpegio.read() OK")
print("\n[J4] Accessing coefficient arrays...") print("\n[J4] Accessing coefficient arrays...")
coef = jpeg.coef_arrays[0] coef = jpeg.coef_arrays[0]
print(f" Coef shape: {coef.shape}, dtype: {coef.dtype}") print(f" Coef shape: {coef.shape}, dtype: {coef.dtype}")
print("\n[J5] Counting usable positions...") print("\n[J5] Counting usable positions...")
positions = [] positions = []
h, w = coef.shape h, w = coef.shape
@@ -171,31 +171,31 @@ if HAS_JPEGIO:
if abs(coef[row, col]) >= 2: if abs(coef[row, col]) >= 2:
positions.append((row, col)) positions.append((row, col))
print(f" Usable positions: {len(positions)}") print(f" Usable positions: {len(positions)}")
print("\n[J6] Cleaning up jpegio object...") print("\n[J6] Cleaning up jpegio object...")
del coef del coef
del jpeg del jpeg
gc.collect() gc.collect()
print(" Deleted jpeg object") print(" Deleted jpeg object")
print("\n[J7] Removing temp file...") print("\n[J7] Removing temp file...")
os.unlink(temp_path) os.unlink(temp_path)
print(" Temp file removed") print(" Temp file removed")
gc.collect() gc.collect()
print("\n[J8] Final GC...") print("\n[J8] Final GC...")
except Exception as e: except Exception as e:
print(f" ERROR: {e}") print(f" ERROR: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
print("\n[J9] Waiting for delayed crash...") print("\n[J9] Waiting for delayed crash...")
for i in range(3): for i in range(3):
time.sleep(1) time.sleep(1)
print(f" {i+1}s...") print(f" {i+1}s...")
gc.collect() gc.collect()
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("JPEGIO TEST PASSED - No crash detected") print("JPEGIO TEST PASSED - No crash detected")
print("=" * 60) print("=" * 60)

View File

@@ -61,16 +61,16 @@ except ImportError:
print("\n[3] BASIC DCT TEST (8x8 block)") print("\n[3] BASIC DCT TEST (8x8 block)")
try: try:
test_block = np.random.rand(8, 8).astype(np.float64) test_block = np.random.rand(8, 8).astype(np.float64)
# 1D DCT on rows # 1D DCT on rows
result = dct(test_block[0, :], norm='ortho') result = dct(test_block[0, :], norm='ortho')
print(f" 1D DCT: OK (output shape: {result.shape})") print(f" 1D DCT: OK (output shape: {result.shape})")
# 1D IDCT # 1D IDCT
recovered = idct(result, norm='ortho') recovered = idct(result, norm='ortho')
error = np.max(np.abs(test_block[0, :] - recovered)) error = np.max(np.abs(test_block[0, :] - recovered))
print(f" 1D IDCT: OK (roundtrip error: {error:.2e})") print(f" 1D IDCT: OK (roundtrip error: {error:.2e})")
# 2D via separable # 2D via separable
temp = np.zeros_like(test_block) temp = np.zeros_like(test_block)
for i in range(8): for i in range(8):
@@ -79,10 +79,10 @@ try:
for i in range(8): for i in range(8):
result2d[i, :] = dct(temp[i, :], norm='ortho') result2d[i, :] = dct(temp[i, :], norm='ortho')
print(f" 2D DCT: OK") print(f" 2D DCT: OK")
gc.collect() gc.collect()
print(" GC after basic test: OK") print(" GC after basic test: OK")
except Exception as e: except Exception as e:
print(f" FAILED: {e}") print(f" FAILED: {e}")
traceback.print_exc() traceback.print_exc()
@@ -92,10 +92,10 @@ print("\n[4] STRESS TEST (many 8x8 blocks)")
try: try:
NUM_BLOCKS = 10000 NUM_BLOCKS = 10000
print(f" Processing {NUM_BLOCKS} blocks...") print(f" Processing {NUM_BLOCKS} blocks...")
for i in range(NUM_BLOCKS): for i in range(NUM_BLOCKS):
block = np.random.rand(8, 8).astype(np.float64) block = np.random.rand(8, 8).astype(np.float64)
# Forward DCT # Forward DCT
temp = np.zeros_like(block) temp = np.zeros_like(block)
for j in range(8): for j in range(8):
@@ -103,7 +103,7 @@ try:
result = np.zeros_like(temp) result = np.zeros_like(temp)
for j in range(8): for j in range(8):
result[j, :] = dct(temp[j, :], norm='ortho') result[j, :] = dct(temp[j, :], norm='ortho')
# Inverse DCT # Inverse DCT
temp2 = np.zeros_like(result) temp2 = np.zeros_like(result)
for j in range(8): for j in range(8):
@@ -111,14 +111,14 @@ try:
recovered = np.zeros_like(temp2) recovered = np.zeros_like(temp2)
for j in range(8): for j in range(8):
recovered[:, j] = idct(temp2[:, j], norm='ortho') recovered[:, j] = idct(temp2[:, j], norm='ortho')
if i % 1000 == 0: if i % 1000 == 0:
gc.collect() gc.collect()
print(f" {i}/{NUM_BLOCKS} blocks processed...") print(f" {i}/{NUM_BLOCKS} blocks processed...")
gc.collect() gc.collect()
print(f" Stress test PASSED") print(f" Stress test PASSED")
except Exception as e: except Exception as e:
print(f" FAILED at block {i}: {e}") print(f" FAILED at block {i}: {e}")
traceback.print_exc() traceback.print_exc()
@@ -127,18 +127,18 @@ except Exception as e:
if len(sys.argv) > 1: if len(sys.argv) > 1:
image_path = sys.argv[1] image_path = sys.argv[1]
print(f"\n[5] IMAGE TEST: {image_path}") print(f"\n[5] IMAGE TEST: {image_path}")
try: try:
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
image_data = f.read() image_data = f.read()
print(f" File size: {len(image_data) / 1024 / 1024:.2f} MB") print(f" File size: {len(image_data) / 1024 / 1024:.2f} MB")
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
width, height = img.size width, height = img.size
print(f" Dimensions: {width}x{height}") print(f" Dimensions: {width}x{height}")
print(f" Format: {img.format}") print(f" Format: {img.format}")
print(f" Mode: {img.mode}") print(f" Mode: {img.mode}")
# Convert to grayscale float array # Convert to grayscale float array
gray = img.convert('L') gray = img.convert('L')
arr = np.array(gray, dtype=np.float64) arr = np.array(gray, dtype=np.float64)
@@ -146,35 +146,35 @@ if len(sys.argv) > 1:
gray.close() gray.close()
print(f" Array shape: {arr.shape}") print(f" Array shape: {arr.shape}")
print(f" Array dtype: {arr.dtype}") print(f" Array dtype: {arr.dtype}")
# Pad to block boundary # Pad to block boundary
BLOCK_SIZE = 8 BLOCK_SIZE = 8
h, w = arr.shape h, w = arr.shape
new_h = ((h + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE new_h = ((h + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
new_w = ((w + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE new_w = ((w + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
if new_h != h or new_w != w: if new_h != h or new_w != w:
padded = np.zeros((new_h, new_w), dtype=np.float64) padded = np.zeros((new_h, new_w), dtype=np.float64)
padded[:h, :w] = arr padded[:h, :w] = arr
arr = padded arr = padded
print(f" Padded to: {arr.shape}") print(f" Padded to: {arr.shape}")
blocks_y = arr.shape[0] // BLOCK_SIZE blocks_y = arr.shape[0] // BLOCK_SIZE
blocks_x = arr.shape[1] // BLOCK_SIZE blocks_x = arr.shape[1] // BLOCK_SIZE
total_blocks = blocks_y * blocks_x total_blocks = blocks_y * blocks_x
print(f" Total 8x8 blocks: {total_blocks}") print(f" Total 8x8 blocks: {total_blocks}")
# Process ALL blocks # Process ALL blocks
print(f" Processing all blocks with DCT...") print(f" Processing all blocks with DCT...")
processed = 0 processed = 0
for by in range(blocks_y): for by in range(blocks_y):
for bx in range(blocks_x): for bx in range(blocks_x):
y = by * BLOCK_SIZE y = by * BLOCK_SIZE
x = bx * BLOCK_SIZE x = bx * BLOCK_SIZE
block = arr[y:y+BLOCK_SIZE, x:x+BLOCK_SIZE].copy() block = arr[y:y+BLOCK_SIZE, x:x+BLOCK_SIZE].copy()
# Forward DCT # Forward DCT
temp = np.zeros((8, 8), dtype=np.float64) temp = np.zeros((8, 8), dtype=np.float64)
for i in range(8): for i in range(8):
@@ -182,7 +182,7 @@ if len(sys.argv) > 1:
dct_block = np.zeros((8, 8), dtype=np.float64) dct_block = np.zeros((8, 8), dtype=np.float64)
for i in range(8): for i in range(8):
dct_block[i, :] = dct(temp[i, :], norm='ortho') dct_block[i, :] = dct(temp[i, :], norm='ortho')
# Inverse DCT # Inverse DCT
temp2 = np.zeros((8, 8), dtype=np.float64) temp2 = np.zeros((8, 8), dtype=np.float64)
for i in range(8): for i in range(8):
@@ -190,17 +190,17 @@ if len(sys.argv) > 1:
recovered = np.zeros((8, 8), dtype=np.float64) recovered = np.zeros((8, 8), dtype=np.float64)
for i in range(8): for i in range(8):
recovered[:, i] = idct(temp2[:, i], norm='ortho') recovered[:, i] = idct(temp2[:, i], norm='ortho')
processed += 1 processed += 1
# GC after each row of blocks # GC after each row of blocks
if by % 50 == 0: if by % 50 == 0:
gc.collect() gc.collect()
print(f" Row {by}/{blocks_y} ({processed}/{total_blocks} blocks)") print(f" Row {by}/{blocks_y} ({processed}/{total_blocks} blocks)")
gc.collect() gc.collect()
print(f" Image DCT test PASSED ({processed} blocks)") print(f" Image DCT test PASSED ({processed} blocks)")
except Exception as e: except Exception as e:
print(f" FAILED: {e}") print(f" FAILED: {e}")
traceback.print_exc() traceback.print_exc()

View File

@@ -7,18 +7,19 @@ Updated for v4.0.0:
- BatchCredentials.passphrase is a single string - BatchCredentials.passphrase is a single string
""" """
import pytest
import tempfile
import shutil import shutil
import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import Mock, patch from unittest.mock import Mock
import pytest
from stegasoo.batch import ( from stegasoo.batch import (
BatchCredentials,
BatchItem,
BatchProcessor, BatchProcessor,
BatchResult, BatchResult,
BatchItem,
BatchStatus, BatchStatus,
BatchCredentials,
batch_capacity_check, batch_capacity_check,
print_batch_result, print_batch_result,
) )
@@ -36,14 +37,14 @@ def temp_dir():
def sample_images(temp_dir): def sample_images(temp_dir):
"""Create sample PNG images for testing.""" """Create sample PNG images for testing."""
from PIL import Image from PIL import Image
images = [] images = []
for i in range(3): for i in range(3):
img_path = temp_dir / f"test_image_{i}.png" img_path = temp_dir / f"test_image_{i}.png"
img = Image.new('RGB', (100, 100), color=(i * 50, i * 50, i * 50)) img = Image.new('RGB', (100, 100), color=(i * 50, i * 50, i * 50))
img.save(img_path, 'PNG') img.save(img_path, 'PNG')
images.append(img_path) images.append(img_path)
return images return images
@@ -58,19 +59,19 @@ def sample_credentials():
class TestBatchItem: class TestBatchItem:
"""Tests for BatchItem dataclass.""" """Tests for BatchItem dataclass."""
def test_duration_calculation(self): def test_duration_calculation(self):
"""Duration should be calculated from start/end times.""" """Duration should be calculated from start/end times."""
item = BatchItem(input_path=Path("test.png")) item = BatchItem(input_path=Path("test.png"))
item.start_time = 100.0 item.start_time = 100.0
item.end_time = 105.5 item.end_time = 105.5
assert item.duration == 5.5 assert item.duration == 5.5
def test_duration_none_without_times(self): def test_duration_none_without_times(self):
"""Duration should be None if times not set.""" """Duration should be None if times not set."""
item = BatchItem(input_path=Path("test.png")) item = BatchItem(input_path=Path("test.png"))
assert item.duration is None assert item.duration is None
def test_to_dict(self): def test_to_dict(self):
"""to_dict should serialize all fields.""" """to_dict should serialize all fields."""
item = BatchItem( item = BatchItem(
@@ -87,7 +88,7 @@ class TestBatchItem:
class TestBatchResult: class TestBatchResult:
"""Tests for BatchResult dataclass.""" """Tests for BatchResult dataclass."""
def test_to_json(self): def test_to_json(self):
"""Should serialize to valid JSON.""" """Should serialize to valid JSON."""
import json import json
@@ -96,7 +97,7 @@ class TestBatchResult:
parsed = json.loads(json_str) parsed = json.loads(json_str)
assert parsed['operation'] == "encode" assert parsed['operation'] == "encode"
assert parsed['summary']['total'] == 5 assert parsed['summary']['total'] == 5
def test_duration_with_end_time(self): def test_duration_with_end_time(self):
"""Duration should work when end_time is set.""" """Duration should work when end_time is set."""
result = BatchResult(operation="test") result = BatchResult(operation="test")
@@ -107,7 +108,7 @@ class TestBatchResult:
class TestBatchCredentials: class TestBatchCredentials:
"""Tests for BatchCredentials dataclass (v3.2.0).""" """Tests for BatchCredentials dataclass (v3.2.0)."""
def test_from_dict_new_format(self): def test_from_dict_new_format(self):
"""Should parse v3.2.0 format with 'passphrase' key.""" """Should parse v3.2.0 format with 'passphrase' key."""
data = { data = {
@@ -117,7 +118,7 @@ class TestBatchCredentials:
creds = BatchCredentials.from_dict(data) creds = BatchCredentials.from_dict(data)
assert creds.passphrase == "test phrase four words" assert creds.passphrase == "test phrase four words"
assert creds.pin == "123456" assert creds.pin == "123456"
def test_from_dict_legacy_format(self): def test_from_dict_legacy_format(self):
"""Should parse legacy format with 'day_phrase' key for migration.""" """Should parse legacy format with 'day_phrase' key for migration."""
data = { data = {
@@ -128,7 +129,7 @@ class TestBatchCredentials:
# Should accept old key and map to passphrase # Should accept old key and map to passphrase
assert creds.passphrase == "legacy phrase here" assert creds.passphrase == "legacy phrase here"
assert creds.pin == "123456" assert creds.pin == "123456"
def test_to_dict(self): def test_to_dict(self):
"""Should serialize to v3.2.0 format.""" """Should serialize to v3.2.0 format."""
creds = BatchCredentials( creds = BatchCredentials(
@@ -139,7 +140,7 @@ class TestBatchCredentials:
assert result['passphrase'] == "test phrase four words" assert result['passphrase'] == "test phrase four words"
assert result['pin'] == "123456" assert result['pin'] == "123456"
assert 'day_phrase' not in result # Old key should not be present assert 'day_phrase' not in result # Old key should not be present
def test_passphrase_is_string(self): def test_passphrase_is_string(self):
"""Passphrase should be a string, not a dict.""" """Passphrase should be a string, not a dict."""
creds = BatchCredentials( creds = BatchCredentials(
@@ -151,59 +152,59 @@ class TestBatchCredentials:
class TestBatchProcessor: class TestBatchProcessor:
"""Tests for BatchProcessor class.""" """Tests for BatchProcessor class."""
def test_init_default_workers(self): def test_init_default_workers(self):
"""Should default to 4 workers.""" """Should default to 4 workers."""
processor = BatchProcessor() processor = BatchProcessor()
assert processor.max_workers == 4 assert processor.max_workers == 4
def test_init_custom_workers(self): def test_init_custom_workers(self):
"""Should accept custom worker count.""" """Should accept custom worker count."""
processor = BatchProcessor(max_workers=8) processor = BatchProcessor(max_workers=8)
assert processor.max_workers == 8 assert processor.max_workers == 8
def test_is_valid_image_png(self, temp_dir): def test_is_valid_image_png(self, temp_dir):
"""Should recognize PNG as valid.""" """Should recognize PNG as valid."""
processor = BatchProcessor() processor = BatchProcessor()
png_path = temp_dir / "test.png" png_path = temp_dir / "test.png"
png_path.touch() png_path.touch()
assert processor._is_valid_image(png_path) assert processor._is_valid_image(png_path)
def test_is_valid_image_txt(self, temp_dir): def test_is_valid_image_txt(self, temp_dir):
"""Should reject non-image files.""" """Should reject non-image files."""
processor = BatchProcessor() processor = BatchProcessor()
txt_path = temp_dir / "test.txt" txt_path = temp_dir / "test.txt"
txt_path.touch() txt_path.touch()
assert not processor._is_valid_image(txt_path) assert not processor._is_valid_image(txt_path)
def test_find_images_file(self, sample_images): def test_find_images_file(self, sample_images):
"""Should find single image file.""" """Should find single image file."""
processor = BatchProcessor() processor = BatchProcessor()
results = list(processor.find_images([sample_images[0]])) results = list(processor.find_images([sample_images[0]]))
assert len(results) == 1 assert len(results) == 1
assert results[0] == sample_images[0] assert results[0] == sample_images[0]
def test_find_images_directory(self, sample_images, temp_dir): def test_find_images_directory(self, sample_images, temp_dir):
"""Should find images in directory.""" """Should find images in directory."""
processor = BatchProcessor() processor = BatchProcessor()
results = list(processor.find_images([temp_dir])) results = list(processor.find_images([temp_dir]))
assert len(results) == 3 assert len(results) == 3
def test_find_images_recursive(self, temp_dir): def test_find_images_recursive(self, temp_dir):
"""Should find images recursively.""" """Should find images recursively."""
from PIL import Image from PIL import Image
# Create nested directory # Create nested directory
nested = temp_dir / "nested" nested = temp_dir / "nested"
nested.mkdir() nested.mkdir()
img_path = nested / "nested.png" img_path = nested / "nested.png"
img = Image.new('RGB', (50, 50)) img = Image.new('RGB', (50, 50))
img.save(img_path) img.save(img_path)
processor = BatchProcessor() processor = BatchProcessor()
results = list(processor.find_images([temp_dir], recursive=True)) results = list(processor.find_images([temp_dir], recursive=True))
assert any(p.name == "nested.png" for p in results) assert any(p.name == "nested.png" for p in results)
def test_batch_encode_requires_message_or_file(self, sample_images, sample_credentials): def test_batch_encode_requires_message_or_file(self, sample_images, sample_credentials):
"""Should raise if neither message nor file provided.""" """Should raise if neither message nor file provided."""
processor = BatchProcessor() processor = BatchProcessor()
@@ -212,7 +213,7 @@ class TestBatchProcessor:
images=sample_images, images=sample_images,
credentials=sample_credentials, credentials=sample_credentials,
) )
def test_batch_encode_requires_credentials(self, sample_images): def test_batch_encode_requires_credentials(self, sample_images):
"""Should raise if credentials not provided.""" """Should raise if credentials not provided."""
processor = BatchProcessor() processor = BatchProcessor()
@@ -221,7 +222,7 @@ class TestBatchProcessor:
images=sample_images, images=sample_images,
message="test", message="test",
) )
def test_batch_encode_accepts_passphrase_credentials(self, sample_images, temp_dir, sample_credentials): def test_batch_encode_accepts_passphrase_credentials(self, sample_images, temp_dir, sample_credentials):
"""Should accept v3.2.0 format credentials with passphrase.""" """Should accept v3.2.0 format credentials with passphrase."""
processor = BatchProcessor() processor = BatchProcessor()
@@ -231,11 +232,11 @@ class TestBatchProcessor:
output_dir=temp_dir / "output", output_dir=temp_dir / "output",
credentials=sample_credentials, # Uses 'passphrase' key credentials=sample_credentials, # Uses 'passphrase' key
) )
assert isinstance(result, BatchResult) assert isinstance(result, BatchResult)
assert result.operation == "encode" assert result.operation == "encode"
assert result.total == 3 assert result.total == 3
def test_batch_encode_creates_result(self, sample_images, temp_dir, sample_credentials): def test_batch_encode_creates_result(self, sample_images, temp_dir, sample_credentials):
"""Should return BatchResult with correct structure.""" """Should return BatchResult with correct structure."""
processor = BatchProcessor() processor = BatchProcessor()
@@ -245,18 +246,18 @@ class TestBatchProcessor:
output_dir=temp_dir / "output", output_dir=temp_dir / "output",
credentials=sample_credentials, credentials=sample_credentials,
) )
assert isinstance(result, BatchResult) assert isinstance(result, BatchResult)
assert result.operation == "encode" assert result.operation == "encode"
assert result.total == 3 assert result.total == 3
assert len(result.items) == 3 assert len(result.items) == 3
def test_batch_decode_requires_credentials(self, sample_images): def test_batch_decode_requires_credentials(self, sample_images):
"""Should raise if credentials not provided.""" """Should raise if credentials not provided."""
processor = BatchProcessor() processor = BatchProcessor()
with pytest.raises(ValueError, match="Credentials"): with pytest.raises(ValueError, match="Credentials"):
processor.batch_decode(images=sample_images) processor.batch_decode(images=sample_images)
def test_batch_decode_accepts_passphrase_credentials(self, sample_images, sample_credentials): def test_batch_decode_accepts_passphrase_credentials(self, sample_images, sample_credentials):
"""Should accept v3.2.0 format credentials with passphrase.""" """Should accept v3.2.0 format credentials with passphrase."""
processor = BatchProcessor() processor = BatchProcessor()
@@ -264,11 +265,11 @@ class TestBatchProcessor:
images=sample_images, images=sample_images,
credentials=sample_credentials, # Uses 'passphrase' key credentials=sample_credentials, # Uses 'passphrase' key
) )
assert isinstance(result, BatchResult) assert isinstance(result, BatchResult)
assert result.operation == "decode" assert result.operation == "decode"
assert result.total == 3 assert result.total == 3
def test_batch_decode_creates_result(self, sample_images, sample_credentials): def test_batch_decode_creates_result(self, sample_images, sample_credentials):
"""Should return BatchResult with correct structure.""" """Should return BatchResult with correct structure."""
processor = BatchProcessor() processor = BatchProcessor()
@@ -276,30 +277,30 @@ class TestBatchProcessor:
images=sample_images, images=sample_images,
credentials=sample_credentials, credentials=sample_credentials,
) )
assert isinstance(result, BatchResult) assert isinstance(result, BatchResult)
assert result.operation == "decode" assert result.operation == "decode"
assert result.total == 3 assert result.total == 3
def test_progress_callback_called(self, sample_images, sample_credentials): def test_progress_callback_called(self, sample_images, sample_credentials):
"""Progress callback should be called for each item.""" """Progress callback should be called for each item."""
processor = BatchProcessor() processor = BatchProcessor()
callback = Mock() callback = Mock()
processor.batch_encode( processor.batch_encode(
images=sample_images, images=sample_images,
message="Test", message="Test",
credentials=sample_credentials, credentials=sample_credentials,
progress_callback=callback, progress_callback=callback,
) )
assert callback.call_count == 3 assert callback.call_count == 3
def test_custom_encode_func(self, sample_images, temp_dir, sample_credentials): def test_custom_encode_func(self, sample_images, temp_dir, sample_credentials):
"""Should use custom encode function if provided.""" """Should use custom encode function if provided."""
processor = BatchProcessor() processor = BatchProcessor()
encode_mock = Mock() encode_mock = Mock()
processor.batch_encode( processor.batch_encode(
images=sample_images, images=sample_images,
message="Test", message="Test",
@@ -307,19 +308,19 @@ class TestBatchProcessor:
credentials=sample_credentials, credentials=sample_credentials,
encode_func=encode_mock, encode_func=encode_mock,
) )
assert encode_mock.call_count == 3 assert encode_mock.call_count == 3
class TestBatchCapacityCheck: class TestBatchCapacityCheck:
"""Tests for batch_capacity_check function.""" """Tests for batch_capacity_check function."""
def test_returns_list(self, sample_images): def test_returns_list(self, sample_images):
"""Should return list of results.""" """Should return list of results."""
results = batch_capacity_check(sample_images) results = batch_capacity_check(sample_images)
assert isinstance(results, list) assert isinstance(results, list)
assert len(results) == 3 assert len(results) == 3
def test_includes_capacity(self, sample_images): def test_includes_capacity(self, sample_images):
"""Results should include capacity info.""" """Results should include capacity info."""
results = batch_capacity_check(sample_images) results = batch_capacity_check(sample_images)
@@ -327,12 +328,12 @@ class TestBatchCapacityCheck:
assert 'capacity_bytes' in item assert 'capacity_bytes' in item
assert 'dimensions' in item assert 'dimensions' in item
assert 'valid' in item assert 'valid' in item
def test_handles_invalid_files(self, temp_dir): def test_handles_invalid_files(self, temp_dir):
"""Should handle non-image files gracefully.""" """Should handle non-image files gracefully."""
bad_file = temp_dir / "not_an_image.png" bad_file = temp_dir / "not_an_image.png"
bad_file.write_bytes(b"not a png") bad_file.write_bytes(b"not a png")
results = batch_capacity_check([bad_file]) results = batch_capacity_check([bad_file])
assert len(results) == 1 assert len(results) == 1
assert 'error' in results[0] assert 'error' in results[0]
@@ -340,7 +341,7 @@ class TestBatchCapacityCheck:
class TestPrintBatchResult: class TestPrintBatchResult:
"""Tests for print_batch_result function.""" """Tests for print_batch_result function."""
def test_prints_summary(self, capsys, sample_images): def test_prints_summary(self, capsys, sample_images):
"""Should print summary without errors.""" """Should print summary without errors."""
result = BatchResult( result = BatchResult(
@@ -350,14 +351,14 @@ class TestPrintBatchResult:
failed=1, failed=1,
) )
result.end_time = result.start_time + 5.0 result.end_time = result.start_time + 5.0
print_batch_result(result) print_batch_result(result)
captured = capsys.readouterr() captured = capsys.readouterr()
assert "ENCODE" in captured.out assert "ENCODE" in captured.out
assert "3" in captured.out # total assert "3" in captured.out # total
assert "2" in captured.out # succeeded assert "2" in captured.out # succeeded
def test_verbose_shows_items(self, capsys): def test_verbose_shows_items(self, capsys):
"""Verbose mode should show individual items.""" """Verbose mode should show individual items."""
result = BatchResult(operation="decode", total=1, succeeded=1) result = BatchResult(operation="decode", total=1, succeeded=1)
@@ -369,16 +370,16 @@ class TestPrintBatchResult:
) )
] ]
result.end_time = result.start_time + 1.0 result.end_time = result.start_time + 1.0
print_batch_result(result, verbose=True) print_batch_result(result, verbose=True)
captured = capsys.readouterr() captured = capsys.readouterr()
assert "test.png" in captured.out assert "test.png" in captured.out
class TestCredentialsMigration: class TestCredentialsMigration:
"""Tests for v3.1.x to v3.2.0 credentials migration.""" """Tests for v3.1.x to v3.2.0 credentials migration."""
def test_old_phrase_key_accepted(self): def test_old_phrase_key_accepted(self):
"""Old 'phrase' key should be accepted for migration.""" """Old 'phrase' key should be accepted for migration."""
old_format = { old_format = {
@@ -388,7 +389,7 @@ class TestCredentialsMigration:
# Should not raise # Should not raise
creds = BatchCredentials.from_dict(old_format) creds = BatchCredentials.from_dict(old_format)
assert creds.passphrase == "old style phrase" assert creds.passphrase == "old style phrase"
def test_old_day_phrase_key_accepted(self): def test_old_day_phrase_key_accepted(self):
"""Old 'day_phrase' key should be accepted for migration.""" """Old 'day_phrase' key should be accepted for migration."""
old_format = { old_format = {
@@ -397,7 +398,7 @@ class TestCredentialsMigration:
} }
creds = BatchCredentials.from_dict(old_format) creds = BatchCredentials.from_dict(old_format)
assert creds.passphrase == "old day phrase" assert creds.passphrase == "old day phrase"
def test_new_passphrase_key_preferred(self): def test_new_passphrase_key_preferred(self):
"""New 'passphrase' key should take precedence if both present.""" """New 'passphrase' key should take precedence if both present."""
mixed_format = { mixed_format = {

View File

@@ -3,24 +3,25 @@ Tests for Stegasoo compression module.
""" """
import pytest import pytest
from stegasoo.compression import ( from stegasoo.compression import (
compress,
decompress,
CompressionAlgorithm,
CompressionError,
get_compression_ratio,
estimate_compressed_size,
get_available_algorithms,
algorithm_name,
MIN_COMPRESS_SIZE,
COMPRESSION_MAGIC, COMPRESSION_MAGIC,
HAS_LZ4, HAS_LZ4,
MIN_COMPRESS_SIZE,
CompressionAlgorithm,
CompressionError,
algorithm_name,
compress,
decompress,
estimate_compressed_size,
get_available_algorithms,
get_compression_ratio,
) )
class TestCompress: class TestCompress:
"""Tests for compress function.""" """Tests for compress function."""
def test_compress_small_data_not_compressed(self): def test_compress_small_data_not_compressed(self):
"""Small data should not be compressed (overhead not worth it).""" """Small data should not be compressed (overhead not worth it)."""
small_data = b"hello" small_data = b"hello"
@@ -28,7 +29,7 @@ class TestCompress:
# Should have magic header but NONE algorithm # Should have magic header but NONE algorithm
assert result.startswith(COMPRESSION_MAGIC) assert result.startswith(COMPRESSION_MAGIC)
assert result[4] == CompressionAlgorithm.NONE assert result[4] == CompressionAlgorithm.NONE
def test_compress_zlib_reduces_size(self): def test_compress_zlib_reduces_size(self):
"""Zlib should reduce size for compressible data.""" """Zlib should reduce size for compressible data."""
# Highly compressible data # Highly compressible data
@@ -37,7 +38,7 @@ class TestCompress:
assert len(result) < len(data) assert len(result) < len(data)
assert result.startswith(COMPRESSION_MAGIC) assert result.startswith(COMPRESSION_MAGIC)
assert result[4] == CompressionAlgorithm.ZLIB assert result[4] == CompressionAlgorithm.ZLIB
def test_compress_incompressible_data(self): def test_compress_incompressible_data(self):
"""Incompressible data should be stored uncompressed.""" """Incompressible data should be stored uncompressed."""
import os import os
@@ -46,7 +47,7 @@ class TestCompress:
result = compress(data, CompressionAlgorithm.ZLIB) result = compress(data, CompressionAlgorithm.ZLIB)
# Should fall back to NONE if compression didn't help # Should fall back to NONE if compression didn't help
assert result.startswith(COMPRESSION_MAGIC) assert result.startswith(COMPRESSION_MAGIC)
def test_compress_none_algorithm(self): def test_compress_none_algorithm(self):
"""NONE algorithm should just wrap data.""" """NONE algorithm should just wrap data."""
data = b"Test data" * 100 data = b"Test data" * 100
@@ -55,7 +56,7 @@ class TestCompress:
assert result[4] == CompressionAlgorithm.NONE assert result[4] == CompressionAlgorithm.NONE
# Data should be after 9-byte header # Data should be after 9-byte header
assert result[9:] == data assert result[9:] == data
@pytest.mark.skipif(not HAS_LZ4, reason="LZ4 not installed") @pytest.mark.skipif(not HAS_LZ4, reason="LZ4 not installed")
def test_compress_lz4(self): def test_compress_lz4(self):
"""LZ4 compression should work if available.""" """LZ4 compression should work if available."""
@@ -68,33 +69,33 @@ class TestCompress:
class TestDecompress: class TestDecompress:
"""Tests for decompress function.""" """Tests for decompress function."""
def test_decompress_zlib(self): def test_decompress_zlib(self):
"""Decompression should restore original data.""" """Decompression should restore original data."""
original = b"Hello, World! " * 100 original = b"Hello, World! " * 100
compressed = compress(original, CompressionAlgorithm.ZLIB) compressed = compress(original, CompressionAlgorithm.ZLIB)
result = decompress(compressed) result = decompress(compressed)
assert result == original assert result == original
def test_decompress_none(self): def test_decompress_none(self):
"""Uncompressed wrapped data should decompress correctly.""" """Uncompressed wrapped data should decompress correctly."""
original = b"Small data" original = b"Small data"
wrapped = compress(original, CompressionAlgorithm.NONE) wrapped = compress(original, CompressionAlgorithm.NONE)
result = decompress(wrapped) result = decompress(wrapped)
assert result == original assert result == original
def test_decompress_no_magic(self): def test_decompress_no_magic(self):
"""Data without magic header should be returned as-is.""" """Data without magic header should be returned as-is."""
data = b"Not compressed at all" data = b"Not compressed at all"
result = decompress(data) result = decompress(data)
assert result == data assert result == data
def test_decompress_truncated_header(self): def test_decompress_truncated_header(self):
"""Truncated header should raise CompressionError.""" """Truncated header should raise CompressionError."""
bad_data = COMPRESSION_MAGIC + b"\x01" # Too short bad_data = COMPRESSION_MAGIC + b"\x01" # Too short
with pytest.raises(CompressionError, match="Truncated"): with pytest.raises(CompressionError, match="Truncated"):
decompress(bad_data) decompress(bad_data)
@pytest.mark.skipif(not HAS_LZ4, reason="LZ4 not installed") @pytest.mark.skipif(not HAS_LZ4, reason="LZ4 not installed")
def test_decompress_lz4(self): def test_decompress_lz4(self):
"""LZ4 decompression should work.""" """LZ4 decompression should work."""
@@ -102,7 +103,7 @@ class TestDecompress:
compressed = compress(original, CompressionAlgorithm.LZ4) compressed = compress(original, CompressionAlgorithm.LZ4)
result = decompress(compressed) result = decompress(compressed)
assert result == original assert result == original
def test_roundtrip_large_data(self): def test_roundtrip_large_data(self):
"""Large data should survive compress/decompress roundtrip.""" """Large data should survive compress/decompress roundtrip."""
import os import os
@@ -114,19 +115,19 @@ class TestDecompress:
class TestUtilities: class TestUtilities:
"""Tests for utility functions.""" """Tests for utility functions."""
def test_compression_ratio_compressed(self): def test_compression_ratio_compressed(self):
"""Ratio should be < 1 for well-compressed data.""" """Ratio should be < 1 for well-compressed data."""
original = b"X" * 1000 original = b"X" * 1000
compressed = compress(original) compressed = compress(original)
ratio = get_compression_ratio(original, compressed) ratio = get_compression_ratio(original, compressed)
assert ratio < 1.0 assert ratio < 1.0
def test_compression_ratio_empty(self): def test_compression_ratio_empty(self):
"""Empty data should return ratio of 1.0.""" """Empty data should return ratio of 1.0."""
ratio = get_compression_ratio(b"", b"") ratio = get_compression_ratio(b"", b"")
assert ratio == 1.0 assert ratio == 1.0
def test_estimate_compressed_size_small(self): def test_estimate_compressed_size_small(self):
"""Small data estimation should be accurate.""" """Small data estimation should be accurate."""
data = b"Test " * 100 data = b"Test " * 100
@@ -134,13 +135,13 @@ class TestUtilities:
actual = len(compress(data)) actual = len(compress(data))
# Should be within 20% for small data # Should be within 20% for small data
assert abs(estimate - actual) / actual < 0.2 assert abs(estimate - actual) / actual < 0.2
def test_available_algorithms(self): def test_available_algorithms(self):
"""Should always include NONE and ZLIB.""" """Should always include NONE and ZLIB."""
algos = get_available_algorithms() algos = get_available_algorithms()
assert CompressionAlgorithm.NONE in algos assert CompressionAlgorithm.NONE in algos
assert CompressionAlgorithm.ZLIB in algos assert CompressionAlgorithm.ZLIB in algos
def test_algorithm_name(self): def test_algorithm_name(self):
"""Algorithm names should be human-readable.""" """Algorithm names should be human-readable."""
assert "Zlib" in algorithm_name(CompressionAlgorithm.ZLIB) assert "Zlib" in algorithm_name(CompressionAlgorithm.ZLIB)
@@ -150,25 +151,25 @@ class TestUtilities:
class TestEdgeCases: class TestEdgeCases:
"""Edge case tests.""" """Edge case tests."""
def test_empty_data(self): def test_empty_data(self):
"""Empty data should be handled gracefully.""" """Empty data should be handled gracefully."""
result = compress(b"") result = compress(b"")
assert decompress(result) == b"" assert decompress(result) == b""
def test_exact_min_size(self): def test_exact_min_size(self):
"""Data at exactly MIN_COMPRESS_SIZE should be compressed.""" """Data at exactly MIN_COMPRESS_SIZE should be compressed."""
data = b"x" * MIN_COMPRESS_SIZE data = b"x" * MIN_COMPRESS_SIZE
result = compress(data, CompressionAlgorithm.ZLIB) result = compress(data, CompressionAlgorithm.ZLIB)
assert result.startswith(COMPRESSION_MAGIC) assert result.startswith(COMPRESSION_MAGIC)
assert decompress(result) == data assert decompress(result) == data
def test_binary_data(self): def test_binary_data(self):
"""Binary data with null bytes should work.""" """Binary data with null bytes should work."""
data = b"\x00\x01\x02\x03" * 500 data = b"\x00\x01\x02\x03" * 500
compressed = compress(data) compressed = compress(data)
assert decompress(compressed) == data assert decompress(compressed) == data
def test_unicode_after_encoding(self): def test_unicode_after_encoding(self):
"""UTF-8 encoded Unicode should compress correctly.""" """UTF-8 encoded Unicode should compress correctly."""
text = "Hello, 世界! 🎉 " * 100 text = "Hello, 世界! 🎉 " * 100

View File

@@ -11,29 +11,28 @@ Updated for v4.0.0:
- Python 3.12 recommended (3.13 not supported) - Python 3.12 recommended (3.13 not supported)
""" """
import io
import pytest import pytest
from PIL import Image from PIL import Image
import io
import stegasoo import stegasoo
from stegasoo import ( from stegasoo import (
generate_pin,
generate_passphrase,
generate_credentials,
validate_pin,
validate_message,
validate_passphrase,
validate_channel_key,
encode,
decode, decode,
decode_text, decode_text,
encode,
generate_channel_key, generate_channel_key,
generate_credentials,
generate_passphrase,
generate_pin,
get_channel_fingerprint, get_channel_fingerprint,
__version__, validate_channel_key,
validate_message,
validate_passphrase,
validate_pin,
) )
from stegasoo.steganography import get_output_format from stegasoo.steganography import get_output_format
# ============================================================================= # =============================================================================
# Fixtures # Fixtures
# ============================================================================= # =============================================================================
@@ -94,7 +93,7 @@ def gif_image():
class TestKeygen: class TestKeygen:
"""Tests for key generation functions.""" """Tests for key generation functions."""
def test_generate_pin_default(self): def test_generate_pin_default(self):
"""Default PIN should be 6 digits, no leading zero.""" """Default PIN should be 6 digits, no leading zero."""
pin = generate_pin() pin = generate_pin()
@@ -183,7 +182,7 @@ class TestKeygen:
class TestValidation: class TestValidation:
"""Tests for validation functions.""" """Tests for validation functions."""
def test_validate_pin_valid(self): def test_validate_pin_valid(self):
"""Valid PIN should pass validation.""" """Valid PIN should pass validation."""
result = validate_pin("123456") result = validate_pin("123456")
@@ -253,7 +252,7 @@ class TestValidation:
class TestOutputFormat: class TestOutputFormat:
"""Tests for output format handling.""" """Tests for output format handling."""
def test_png_stays_png(self): def test_png_stays_png(self):
"""PNG input should produce PNG output.""" """PNG input should produce PNG output."""
fmt, ext = get_output_format('PNG') fmt, ext = get_output_format('PNG')
@@ -310,7 +309,7 @@ class TestConstants:
class TestEncodeDecode: class TestEncodeDecode:
"""Tests for encoding and decoding functions.""" """Tests for encoding and decoding functions."""
def test_encode_decode_roundtrip(self, png_image): def test_encode_decode_roundtrip(self, png_image):
"""Full encode/decode cycle should work.""" """Full encode/decode cycle should work."""
message = "Secret message!" message = "Secret message!"
@@ -501,7 +500,7 @@ class TestEncodeDecode:
class TestDCTMode: class TestDCTMode:
"""Tests for DCT steganography mode.""" """Tests for DCT steganography mode."""
@pytest.fixture @pytest.fixture
def skip_if_no_dct(self): def skip_if_no_dct(self):
"""Skip test if DCT support not available.""" """Skip test if DCT support not available."""
@@ -567,7 +566,7 @@ class TestDCTMode:
class TestVersion: class TestVersion:
"""Tests for version information.""" """Tests for version information."""
def test_version_exists(self): def test_version_exists(self):
"""Version string should exist and be valid.""" """Version string should exist and be valid."""
assert hasattr(stegasoo, '__version__') assert hasattr(stegasoo, '__version__')
@@ -588,7 +587,7 @@ class TestVersion:
class TestBackwardCompatibility: class TestBackwardCompatibility:
"""Tests for backward compatibility handling.""" """Tests for backward compatibility handling."""
def test_old_day_phrase_parameter_raises(self, png_image): def test_old_day_phrase_parameter_raises(self, png_image):
"""Using old day_phrase parameter should raise TypeError.""" """Using old day_phrase parameter should raise TypeError."""
with pytest.raises(TypeError): with pytest.raises(TypeError):