Lint cleanup: ruff fixes across entire codebase
- Strip trailing whitespace from all Python files - Fix import sorting (I001) across all modules - Convert Optional[X] to X | None syntax (UP045) - Remove unused imports (F401) - Convert lambda assignments to def functions (E731) - Add TYPE_CHECKING import for forward references - Update pyproject.toml ruff config: - Move select/ignore to [tool.ruff.lint] section - Add per-file ignores for DCT colorspace naming (N803/N806) - Add per-file ignores for __init__.py import structure (E402) - Exclude defunct test_routes.py - Remove frontends/web/test_routes.py (defunct debug snippet) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -17,7 +17,7 @@ import sys
|
||||
def main():
|
||||
"""
|
||||
Main entry point for Stegasoo CLI.
|
||||
|
||||
|
||||
Delegates to the CLI module for command parsing and execution.
|
||||
"""
|
||||
try:
|
||||
|
||||
@@ -10,56 +10,55 @@ Changes in v4.0.0:
|
||||
__version__ = "4.0.1"
|
||||
|
||||
# Core functionality
|
||||
from .encode import encode
|
||||
# Channel key management (v4.0.0)
|
||||
from .channel import (
|
||||
clear_channel_key,
|
||||
format_channel_key,
|
||||
generate_channel_key,
|
||||
get_channel_key,
|
||||
get_channel_status,
|
||||
has_channel_key,
|
||||
set_channel_key,
|
||||
validate_channel_key,
|
||||
)
|
||||
|
||||
# Crypto functions
|
||||
from .crypto import get_active_channel_key, get_channel_fingerprint, has_argon2
|
||||
from .decode import decode, decode_file, decode_text
|
||||
from .encode import encode
|
||||
|
||||
# Credential generation
|
||||
from .generate import (
|
||||
generate_pin,
|
||||
generate_passphrase,
|
||||
generate_rsa_key,
|
||||
generate_credentials,
|
||||
export_rsa_key_pem,
|
||||
generate_credentials,
|
||||
generate_passphrase,
|
||||
generate_pin,
|
||||
generate_rsa_key,
|
||||
load_rsa_key,
|
||||
)
|
||||
|
||||
# Image utilities
|
||||
from .image_utils import (
|
||||
get_image_info,
|
||||
compare_capacity,
|
||||
get_image_info,
|
||||
)
|
||||
|
||||
# Steganography functions
|
||||
from .steganography import (
|
||||
compare_modes,
|
||||
has_dct_support,
|
||||
will_fit_by_mode,
|
||||
)
|
||||
|
||||
# Utilities
|
||||
from .utils import generate_filename
|
||||
|
||||
# Crypto functions
|
||||
from .crypto import has_argon2, get_active_channel_key, get_channel_fingerprint
|
||||
|
||||
# Channel key management (v4.0.0)
|
||||
from .channel import (
|
||||
generate_channel_key,
|
||||
get_channel_key,
|
||||
set_channel_key,
|
||||
clear_channel_key,
|
||||
has_channel_key,
|
||||
get_channel_status,
|
||||
validate_channel_key,
|
||||
format_channel_key,
|
||||
)
|
||||
|
||||
# Steganography functions
|
||||
from .steganography import (
|
||||
has_dct_support,
|
||||
compare_modes,
|
||||
will_fit_by_mode,
|
||||
)
|
||||
|
||||
# QR Code utilities - optional, may not be available
|
||||
try:
|
||||
from .qr_utils import (
|
||||
generate_qr_code,
|
||||
extract_key_from_qr,
|
||||
detect_and_crop_qr,
|
||||
extract_key_from_qr,
|
||||
generate_qr_code,
|
||||
)
|
||||
HAS_QR_UTILS = True
|
||||
except ImportError:
|
||||
@@ -70,12 +69,12 @@ except ImportError:
|
||||
|
||||
# Validation
|
||||
from .validation import (
|
||||
validate_file_payload,
|
||||
validate_image,
|
||||
validate_message,
|
||||
validate_passphrase,
|
||||
validate_pin,
|
||||
validate_rsa_key,
|
||||
validate_message,
|
||||
validate_file_payload,
|
||||
validate_image,
|
||||
validate_security_factors,
|
||||
)
|
||||
|
||||
@@ -84,62 +83,61 @@ validate_reference_photo = validate_image
|
||||
validate_carrier = validate_image
|
||||
|
||||
# Additional validators
|
||||
from .validation import (
|
||||
validate_embed_mode,
|
||||
validate_dct_output_format,
|
||||
validate_dct_color_mode,
|
||||
)
|
||||
|
||||
# Models
|
||||
from .models import (
|
||||
ImageInfo,
|
||||
CapacityComparison,
|
||||
GenerateResult,
|
||||
EncodeResult,
|
||||
DecodeResult,
|
||||
FilePayload,
|
||||
Credentials,
|
||||
ValidationResult,
|
||||
# Constants
|
||||
from .constants import (
|
||||
DEFAULT_PASSPHRASE_WORDS,
|
||||
EMBED_MODE_AUTO,
|
||||
EMBED_MODE_DCT,
|
||||
EMBED_MODE_LSB,
|
||||
FORMAT_VERSION,
|
||||
LOSSLESS_FORMATS,
|
||||
MAX_IMAGE_PIXELS,
|
||||
MAX_MESSAGE_SIZE,
|
||||
MAX_PASSPHRASE_WORDS,
|
||||
MAX_PIN_LENGTH,
|
||||
MIN_IMAGE_PIXELS,
|
||||
MIN_PASSPHRASE_WORDS,
|
||||
MIN_PIN_LENGTH,
|
||||
RECOMMENDED_PASSPHRASE_WORDS,
|
||||
)
|
||||
|
||||
# Exceptions
|
||||
from .exceptions import (
|
||||
StegasooError,
|
||||
ValidationError,
|
||||
PinValidationError,
|
||||
MessageValidationError,
|
||||
ImageValidationError,
|
||||
KeyValidationError,
|
||||
SecurityFactorError,
|
||||
CapacityError,
|
||||
CryptoError,
|
||||
EncryptionError,
|
||||
DecryptionError,
|
||||
EmbeddingError,
|
||||
EncryptionError,
|
||||
ExtractionError,
|
||||
ImageValidationError,
|
||||
InvalidHeaderError,
|
||||
KeyDerivationError,
|
||||
KeyGenerationError,
|
||||
KeyPasswordError,
|
||||
KeyValidationError,
|
||||
MessageValidationError,
|
||||
PinValidationError,
|
||||
SecurityFactorError,
|
||||
SteganographyError,
|
||||
CapacityError,
|
||||
ExtractionError,
|
||||
EmbeddingError,
|
||||
InvalidHeaderError,
|
||||
StegasooError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
# Constants
|
||||
from .constants import (
|
||||
FORMAT_VERSION,
|
||||
MIN_PASSPHRASE_WORDS,
|
||||
RECOMMENDED_PASSPHRASE_WORDS,
|
||||
DEFAULT_PASSPHRASE_WORDS,
|
||||
MAX_PASSPHRASE_WORDS,
|
||||
MIN_PIN_LENGTH,
|
||||
MAX_PIN_LENGTH,
|
||||
MAX_MESSAGE_SIZE,
|
||||
MIN_IMAGE_PIXELS,
|
||||
MAX_IMAGE_PIXELS,
|
||||
LOSSLESS_FORMATS,
|
||||
EMBED_MODE_LSB,
|
||||
EMBED_MODE_DCT,
|
||||
EMBED_MODE_AUTO,
|
||||
# Models
|
||||
from .models import (
|
||||
CapacityComparison,
|
||||
Credentials,
|
||||
DecodeResult,
|
||||
EncodeResult,
|
||||
FilePayload,
|
||||
GenerateResult,
|
||||
ImageInfo,
|
||||
ValidationResult,
|
||||
)
|
||||
from .validation import (
|
||||
validate_dct_color_mode,
|
||||
validate_dct_output_format,
|
||||
validate_embed_mode,
|
||||
)
|
||||
|
||||
# Aliases for backward compatibility
|
||||
@@ -159,7 +157,7 @@ __all__ = [
|
||||
"decode",
|
||||
"decode_file",
|
||||
"decode_text",
|
||||
|
||||
|
||||
# Generation
|
||||
"generate_pin",
|
||||
"generate_passphrase",
|
||||
@@ -167,7 +165,7 @@ __all__ = [
|
||||
"generate_credentials",
|
||||
"export_rsa_key_pem",
|
||||
"load_rsa_key",
|
||||
|
||||
|
||||
# Channel key management (v4.0.0)
|
||||
"generate_channel_key",
|
||||
"get_channel_key",
|
||||
@@ -179,28 +177,28 @@ __all__ = [
|
||||
"format_channel_key",
|
||||
"get_active_channel_key",
|
||||
"get_channel_fingerprint",
|
||||
|
||||
|
||||
# Image utilities
|
||||
"get_image_info",
|
||||
"compare_capacity",
|
||||
|
||||
|
||||
# Utilities
|
||||
"generate_filename",
|
||||
|
||||
|
||||
# Crypto
|
||||
"has_argon2",
|
||||
|
||||
|
||||
# Steganography
|
||||
"has_dct_support",
|
||||
"compare_modes",
|
||||
"will_fit_by_mode",
|
||||
|
||||
|
||||
# QR utilities
|
||||
"generate_qr_code",
|
||||
"extract_key_from_qr",
|
||||
"detect_and_crop_qr",
|
||||
"HAS_QR_UTILS",
|
||||
|
||||
|
||||
# Validation
|
||||
"validate_reference_photo",
|
||||
"validate_carrier",
|
||||
@@ -214,7 +212,7 @@ __all__ = [
|
||||
"validate_dct_output_format",
|
||||
"validate_dct_color_mode",
|
||||
"validate_channel_key",
|
||||
|
||||
|
||||
# Models
|
||||
"ImageInfo",
|
||||
"CapacityComparison",
|
||||
@@ -224,7 +222,7 @@ __all__ = [
|
||||
"FilePayload",
|
||||
"Credentials",
|
||||
"ValidationResult",
|
||||
|
||||
|
||||
# Exceptions
|
||||
"StegasooError",
|
||||
"ValidationError",
|
||||
@@ -244,7 +242,7 @@ __all__ = [
|
||||
"ExtractionError",
|
||||
"EmbeddingError",
|
||||
"InvalidHeaderError",
|
||||
|
||||
|
||||
# Constants
|
||||
"FORMAT_VERSION",
|
||||
"MIN_PASSPHRASE_WORDS",
|
||||
|
||||
@@ -9,15 +9,14 @@ Changes in v3.2.0:
|
||||
- Updated all credential handling to use v3.2.0 API
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Optional, Callable, Iterator
|
||||
from enum import Enum
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable, Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from .constants import ALLOWED_IMAGE_EXTENSIONS, LOSSLESS_FORMATS
|
||||
|
||||
@@ -35,22 +34,22 @@ class BatchStatus(Enum):
|
||||
class BatchItem:
|
||||
"""Represents a single item in a batch operation."""
|
||||
input_path: Path
|
||||
output_path: Optional[Path] = None
|
||||
output_path: Path | None = None
|
||||
status: BatchStatus = BatchStatus.PENDING
|
||||
error: Optional[str] = None
|
||||
start_time: Optional[float] = None
|
||||
end_time: Optional[float] = None
|
||||
error: str | None = None
|
||||
start_time: float | None = None
|
||||
end_time: float | None = None
|
||||
input_size: int = 0
|
||||
output_size: int = 0
|
||||
message: str = ""
|
||||
|
||||
|
||||
@property
|
||||
def duration(self) -> Optional[float]:
|
||||
def duration(self) -> float | None:
|
||||
"""Processing duration in seconds."""
|
||||
if self.start_time and self.end_time:
|
||||
return self.end_time - self.start_time
|
||||
return None
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
@@ -69,14 +68,14 @@ class BatchItem:
|
||||
class BatchCredentials:
|
||||
"""
|
||||
Credentials for batch encode/decode operations (v3.2.0).
|
||||
|
||||
|
||||
Provides a structured way to pass authentication factors
|
||||
for batch processing instead of using plain dicts.
|
||||
|
||||
|
||||
Changes in v3.2.0:
|
||||
- Renamed day_phrase → passphrase
|
||||
- Removed date_str (no longer used in cryptographic operations)
|
||||
|
||||
|
||||
Example:
|
||||
creds = BatchCredentials(
|
||||
reference_photo=ref_bytes,
|
||||
@@ -88,9 +87,9 @@ class BatchCredentials:
|
||||
reference_photo: bytes
|
||||
passphrase: str # v3.2.0: renamed from day_phrase
|
||||
pin: str = ""
|
||||
rsa_key_data: Optional[bytes] = None
|
||||
rsa_password: Optional[str] = None
|
||||
|
||||
rsa_key_data: bytes | None = None
|
||||
rsa_password: str | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for API compatibility."""
|
||||
return {
|
||||
@@ -100,17 +99,17 @@ class BatchCredentials:
|
||||
"rsa_key_data": self.rsa_key_data,
|
||||
"rsa_password": self.rsa_password,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> 'BatchCredentials':
|
||||
"""
|
||||
Create BatchCredentials from a dictionary.
|
||||
|
||||
|
||||
Handles both v3.2.0 format (passphrase) and legacy format (day_phrase).
|
||||
"""
|
||||
# Handle legacy 'day_phrase' key
|
||||
passphrase = data.get('passphrase') or data.get('day_phrase', '')
|
||||
|
||||
|
||||
return cls(
|
||||
reference_photo=data['reference_photo'],
|
||||
passphrase=passphrase,
|
||||
@@ -129,16 +128,16 @@ class BatchResult:
|
||||
failed: int = 0
|
||||
skipped: int = 0
|
||||
start_time: float = field(default_factory=time.time)
|
||||
end_time: Optional[float] = None
|
||||
end_time: float | None = None
|
||||
items: list[BatchItem] = field(default_factory=list)
|
||||
|
||||
|
||||
@property
|
||||
def duration(self) -> Optional[float]:
|
||||
def duration(self) -> float | None:
|
||||
"""Total batch duration in seconds."""
|
||||
if self.end_time:
|
||||
return self.end_time - self.start_time
|
||||
return None
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
@@ -152,7 +151,7 @@ class BatchResult:
|
||||
},
|
||||
"items": [item.to_dict() for item in self.items],
|
||||
}
|
||||
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
"""Serialize to JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=indent)
|
||||
@@ -165,10 +164,10 @@ ProgressCallback = Callable[[int, int, BatchItem], None]
|
||||
class BatchProcessor:
|
||||
"""
|
||||
Handles batch encoding/decoding operations (v3.2.0).
|
||||
|
||||
|
||||
Usage:
|
||||
processor = BatchProcessor(max_workers=4)
|
||||
|
||||
|
||||
# Batch encode with BatchCredentials
|
||||
creds = BatchCredentials(
|
||||
reference_photo=ref_bytes,
|
||||
@@ -181,7 +180,7 @@ class BatchProcessor:
|
||||
output_dir="./encoded/",
|
||||
credentials=creds,
|
||||
)
|
||||
|
||||
|
||||
# Batch encode with dict credentials
|
||||
result = processor.batch_encode(
|
||||
images=['img1.png', 'img2.png'],
|
||||
@@ -192,24 +191,24 @@ class BatchProcessor:
|
||||
"pin": "123456"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Batch decode
|
||||
result = processor.batch_decode(
|
||||
images=['encoded1.png', 'encoded2.png'],
|
||||
credentials=creds,
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, max_workers: int = 4):
|
||||
"""
|
||||
Initialize batch processor.
|
||||
|
||||
|
||||
Args:
|
||||
max_workers: Maximum parallel workers (default 4)
|
||||
"""
|
||||
self.max_workers = max_workers
|
||||
self._lock = threading.Lock()
|
||||
|
||||
|
||||
def find_images(
|
||||
self,
|
||||
paths: list[str | Path],
|
||||
@@ -217,67 +216,67 @@ class BatchProcessor:
|
||||
) -> Iterator[Path]:
|
||||
"""
|
||||
Find all valid image files from paths.
|
||||
|
||||
|
||||
Args:
|
||||
paths: List of files or directories
|
||||
recursive: Search directories recursively
|
||||
|
||||
|
||||
Yields:
|
||||
Path objects for each valid image
|
||||
"""
|
||||
for path in paths:
|
||||
path = Path(path)
|
||||
|
||||
|
||||
if path.is_file():
|
||||
if self._is_valid_image(path):
|
||||
yield path
|
||||
|
||||
|
||||
elif path.is_dir():
|
||||
pattern = '**/*' if recursive else '*'
|
||||
for file_path in path.glob(pattern):
|
||||
if file_path.is_file() and self._is_valid_image(file_path):
|
||||
yield file_path
|
||||
|
||||
|
||||
def _is_valid_image(self, path: Path) -> bool:
|
||||
"""Check if path is a valid image file."""
|
||||
return path.suffix.lower().lstrip('.') in ALLOWED_IMAGE_EXTENSIONS
|
||||
|
||||
|
||||
def _normalize_credentials(
|
||||
self,
|
||||
self,
|
||||
credentials: dict | BatchCredentials | None
|
||||
) -> BatchCredentials:
|
||||
"""
|
||||
Normalize credentials to BatchCredentials object.
|
||||
|
||||
|
||||
Handles both dict and BatchCredentials input, and legacy 'day_phrase' key.
|
||||
"""
|
||||
if credentials is None:
|
||||
raise ValueError("Credentials are required")
|
||||
|
||||
|
||||
if isinstance(credentials, BatchCredentials):
|
||||
return credentials
|
||||
|
||||
|
||||
if isinstance(credentials, dict):
|
||||
return BatchCredentials.from_dict(credentials)
|
||||
|
||||
|
||||
raise ValueError(f"Invalid credentials type: {type(credentials)}")
|
||||
|
||||
|
||||
def batch_encode(
|
||||
self,
|
||||
images: list[str | Path],
|
||||
message: Optional[str] = None,
|
||||
file_payload: Optional[Path] = None,
|
||||
output_dir: Optional[Path] = None,
|
||||
message: str | None = None,
|
||||
file_payload: Path | None = None,
|
||||
output_dir: Path | None = None,
|
||||
output_suffix: str = "_encoded",
|
||||
credentials: dict | BatchCredentials | None = None,
|
||||
compress: bool = True,
|
||||
recursive: bool = False,
|
||||
progress_callback: Optional[ProgressCallback] = None,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
encode_func: Callable = None,
|
||||
) -> BatchResult:
|
||||
"""
|
||||
Encode message into multiple images.
|
||||
|
||||
|
||||
Args:
|
||||
images: List of image paths or directories
|
||||
message: Text message to encode (mutually exclusive with file_payload)
|
||||
@@ -289,43 +288,43 @@ class BatchProcessor:
|
||||
recursive: Search directories recursively
|
||||
progress_callback: Called for each item: callback(current, total, item)
|
||||
encode_func: Custom encode function (for integration)
|
||||
|
||||
|
||||
Returns:
|
||||
BatchResult with operation summary
|
||||
"""
|
||||
if message is None and file_payload is None:
|
||||
raise ValueError("Either message or file_payload must be provided")
|
||||
|
||||
|
||||
# Normalize credentials to BatchCredentials
|
||||
creds = self._normalize_credentials(credentials)
|
||||
|
||||
|
||||
result = BatchResult(operation="encode")
|
||||
image_paths = list(self.find_images(images, recursive))
|
||||
result.total = len(image_paths)
|
||||
|
||||
|
||||
if output_dir:
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Prepare batch items
|
||||
for img_path in image_paths:
|
||||
if output_dir:
|
||||
out_path = output_dir / f"{img_path.stem}{output_suffix}.png"
|
||||
else:
|
||||
out_path = img_path.parent / f"{img_path.stem}{output_suffix}.png"
|
||||
|
||||
|
||||
item = BatchItem(
|
||||
input_path=img_path,
|
||||
output_path=out_path,
|
||||
input_size=img_path.stat().st_size if img_path.exists() else 0,
|
||||
)
|
||||
result.items.append(item)
|
||||
|
||||
|
||||
# Process items
|
||||
def process_encode(item: BatchItem) -> BatchItem:
|
||||
item.status = BatchStatus.PROCESSING
|
||||
item.start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
if encode_func:
|
||||
# Use provided encode function
|
||||
@@ -340,35 +339,35 @@ class BatchProcessor:
|
||||
else:
|
||||
# Use stegasoo encode
|
||||
self._do_encode(item, message, file_payload, creds, compress)
|
||||
|
||||
|
||||
item.status = BatchStatus.SUCCESS
|
||||
item.output_size = item.output_path.stat().st_size if item.output_path and item.output_path.exists() else 0
|
||||
item.message = f"Encoded to {item.output_path.name}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
item.status = BatchStatus.FAILED
|
||||
item.error = str(e)
|
||||
|
||||
|
||||
item.end_time = time.time()
|
||||
return item
|
||||
|
||||
|
||||
# Execute with thread pool
|
||||
self._execute_batch(result, process_encode, progress_callback)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def batch_decode(
|
||||
self,
|
||||
images: list[str | Path],
|
||||
output_dir: Optional[Path] = None,
|
||||
output_dir: Path | None = None,
|
||||
credentials: dict | BatchCredentials | None = None,
|
||||
recursive: bool = False,
|
||||
progress_callback: Optional[ProgressCallback] = None,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
decode_func: Callable = None,
|
||||
) -> BatchResult:
|
||||
"""
|
||||
Decode messages from multiple images.
|
||||
|
||||
|
||||
Args:
|
||||
images: List of image paths or directories
|
||||
output_dir: Output directory for file payloads (default: same as input)
|
||||
@@ -376,21 +375,21 @@ class BatchProcessor:
|
||||
recursive: Search directories recursively
|
||||
progress_callback: Called for each item: callback(current, total, item)
|
||||
decode_func: Custom decode function (for integration)
|
||||
|
||||
|
||||
Returns:
|
||||
BatchResult with decoded messages in item.message fields
|
||||
"""
|
||||
# Normalize credentials to BatchCredentials
|
||||
creds = self._normalize_credentials(credentials)
|
||||
|
||||
|
||||
result = BatchResult(operation="decode")
|
||||
image_paths = list(self.find_images(images, recursive))
|
||||
result.total = len(image_paths)
|
||||
|
||||
|
||||
if output_dir:
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Prepare batch items
|
||||
for img_path in image_paths:
|
||||
item = BatchItem(
|
||||
@@ -399,12 +398,12 @@ class BatchProcessor:
|
||||
input_size=img_path.stat().st_size if img_path.exists() else 0,
|
||||
)
|
||||
result.items.append(item)
|
||||
|
||||
|
||||
# Process items
|
||||
def process_decode(item: BatchItem) -> BatchItem:
|
||||
item.status = BatchStatus.PROCESSING
|
||||
item.start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
if decode_func:
|
||||
# Use provided decode function
|
||||
@@ -417,40 +416,40 @@ class BatchProcessor:
|
||||
else:
|
||||
# Use stegasoo decode
|
||||
item.message = self._do_decode(item, creds)
|
||||
|
||||
|
||||
item.status = BatchStatus.SUCCESS
|
||||
|
||||
|
||||
except Exception as e:
|
||||
item.status = BatchStatus.FAILED
|
||||
item.error = str(e)
|
||||
|
||||
|
||||
item.end_time = time.time()
|
||||
return item
|
||||
|
||||
|
||||
# Execute with thread pool
|
||||
self._execute_batch(result, process_decode, progress_callback)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _execute_batch(
|
||||
self,
|
||||
result: BatchResult,
|
||||
process_func: Callable[[BatchItem], BatchItem],
|
||||
progress_callback: Optional[ProgressCallback] = None,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
) -> None:
|
||||
"""Execute batch processing with thread pool."""
|
||||
completed = 0
|
||||
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(process_func, item): item
|
||||
executor.submit(process_func, item): item
|
||||
for item in result.items
|
||||
}
|
||||
|
||||
|
||||
for future in as_completed(futures):
|
||||
item = future.result()
|
||||
completed += 1
|
||||
|
||||
|
||||
with self._lock:
|
||||
if item.status == BatchStatus.SUCCESS:
|
||||
result.succeeded += 1
|
||||
@@ -458,32 +457,32 @@ class BatchProcessor:
|
||||
result.failed += 1
|
||||
elif item.status == BatchStatus.SKIPPED:
|
||||
result.skipped += 1
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(completed, result.total, item)
|
||||
|
||||
|
||||
result.end_time = time.time()
|
||||
|
||||
|
||||
def _do_encode(
|
||||
self,
|
||||
item: BatchItem,
|
||||
message: Optional[str],
|
||||
file_payload: Optional[Path],
|
||||
message: str | None,
|
||||
file_payload: Path | None,
|
||||
creds: BatchCredentials,
|
||||
compress: bool
|
||||
) -> None:
|
||||
"""
|
||||
Perform actual encoding using stegasoo.encode.
|
||||
|
||||
|
||||
Override this method to customize encoding behavior.
|
||||
"""
|
||||
try:
|
||||
from .encode import encode, encode_file
|
||||
from .encode import encode
|
||||
from .models import FilePayload
|
||||
|
||||
|
||||
# Read carrier image
|
||||
carrier_image = item.input_path.read_bytes()
|
||||
|
||||
|
||||
if file_payload:
|
||||
# Encode file
|
||||
payload = FilePayload.from_file(str(file_payload))
|
||||
@@ -507,15 +506,15 @@ class BatchProcessor:
|
||||
rsa_key_data=creds.rsa_key_data,
|
||||
rsa_password=creds.rsa_password,
|
||||
)
|
||||
|
||||
|
||||
# Write output
|
||||
if item.output_path:
|
||||
item.output_path.write_bytes(result.stego_image)
|
||||
|
||||
|
||||
except ImportError:
|
||||
# Fallback to mock if stegasoo.encode not available
|
||||
self._mock_encode(item, message, creds, compress)
|
||||
|
||||
|
||||
def _do_decode(
|
||||
self,
|
||||
item: BatchItem,
|
||||
@@ -523,15 +522,15 @@ class BatchProcessor:
|
||||
) -> str:
|
||||
"""
|
||||
Perform actual decoding using stegasoo.decode.
|
||||
|
||||
|
||||
Override this method to customize decoding behavior.
|
||||
"""
|
||||
try:
|
||||
from .decode import decode
|
||||
|
||||
|
||||
# Read stego image
|
||||
stego_image = item.input_path.read_bytes()
|
||||
|
||||
|
||||
result = decode(
|
||||
stego_image=stego_image,
|
||||
reference_photo=creds.reference_photo,
|
||||
@@ -540,7 +539,7 @@ class BatchProcessor:
|
||||
rsa_key_data=creds.rsa_key_data,
|
||||
rsa_password=creds.rsa_password,
|
||||
)
|
||||
|
||||
|
||||
if result.is_text:
|
||||
return result.message or ""
|
||||
else:
|
||||
@@ -550,11 +549,11 @@ class BatchProcessor:
|
||||
output_file.write_bytes(result.file_data)
|
||||
return f"File extracted: {result.filename or 'extracted_file'}"
|
||||
return f"[File: {result.filename or 'binary data'}]"
|
||||
|
||||
|
||||
except ImportError:
|
||||
# Fallback to mock if stegasoo.decode not available
|
||||
return self._mock_decode(item, creds)
|
||||
|
||||
|
||||
def _mock_encode(
|
||||
self,
|
||||
item: BatchItem,
|
||||
@@ -568,7 +567,7 @@ class BatchProcessor:
|
||||
import shutil
|
||||
if item.output_path:
|
||||
shutil.copy(item.input_path, item.output_path)
|
||||
|
||||
|
||||
def _mock_decode(self, item: BatchItem, creds: BatchCredentials) -> str:
|
||||
"""Mock decode for testing - replace with actual stego.decode()"""
|
||||
# This is a placeholder - in real usage, you'd call your actual decode function
|
||||
@@ -581,30 +580,31 @@ def batch_capacity_check(
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Check capacity of multiple images without encoding.
|
||||
|
||||
|
||||
Args:
|
||||
images: List of image paths or directories
|
||||
recursive: Search directories recursively
|
||||
|
||||
|
||||
Returns:
|
||||
List of dicts with path, dimensions, and estimated capacity
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
from .constants import MAX_IMAGE_PIXELS
|
||||
|
||||
|
||||
processor = BatchProcessor()
|
||||
results = []
|
||||
|
||||
|
||||
for img_path in processor.find_images(images, recursive):
|
||||
try:
|
||||
with Image.open(img_path) as img:
|
||||
width, height = img.size
|
||||
pixels = width * height
|
||||
|
||||
|
||||
# Estimate: 3 bits per pixel (RGB LSB), minus header overhead
|
||||
capacity_bits = pixels * 3
|
||||
capacity_bytes = (capacity_bits // 8) - 100 # Header overhead
|
||||
|
||||
|
||||
results.append({
|
||||
"path": str(img_path),
|
||||
"dimensions": f"{width}x{height}",
|
||||
@@ -622,25 +622,25 @@ def batch_capacity_check(
|
||||
"error": str(e),
|
||||
"valid": False,
|
||||
})
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_image_warnings(img, path: Path) -> list[str]:
|
||||
"""Generate warnings for an image."""
|
||||
from .constants import MAX_IMAGE_PIXELS, LOSSLESS_FORMATS
|
||||
|
||||
from .constants import LOSSLESS_FORMATS, MAX_IMAGE_PIXELS
|
||||
|
||||
warnings = []
|
||||
|
||||
|
||||
if img.format not in LOSSLESS_FORMATS:
|
||||
warnings.append(f"Lossy format ({img.format}) - quality will degrade on re-save")
|
||||
|
||||
|
||||
if img.size[0] * img.size[1] > MAX_IMAGE_PIXELS:
|
||||
warnings.append(f"Image exceeds {MAX_IMAGE_PIXELS:,} pixel limit")
|
||||
|
||||
|
||||
if img.mode not in ('RGB', 'RGBA'):
|
||||
warnings.append(f"Non-RGB mode ({img.mode}) - will be converted")
|
||||
|
||||
|
||||
return warnings
|
||||
|
||||
|
||||
@@ -657,7 +657,7 @@ def print_batch_result(result: BatchResult, verbose: bool = False) -> None:
|
||||
print(f"Skipped: {result.skipped}")
|
||||
if result.duration:
|
||||
print(f"Duration: {result.duration:.2f}s")
|
||||
|
||||
|
||||
if verbose or result.failed > 0:
|
||||
print(f"\n{'─'*60}")
|
||||
for item in result.items:
|
||||
@@ -668,7 +668,7 @@ def print_batch_result(result: BatchResult, verbose: bool = False) -> None:
|
||||
BatchStatus.PENDING: "…",
|
||||
BatchStatus.PROCESSING: "⟳",
|
||||
}.get(item.status, "?")
|
||||
|
||||
|
||||
print(f"{status_icon} {item.input_path.name}")
|
||||
if item.error:
|
||||
print(f" Error: {item.error}")
|
||||
|
||||
@@ -24,12 +24,11 @@ INTEGRATION STATUS (v4.0.0):
|
||||
- ✅ Helpful error messages for channel key mismatches
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from .debug import debug
|
||||
|
||||
@@ -52,10 +51,10 @@ CONFIG_LOCATIONS = [
|
||||
def generate_channel_key() -> str:
|
||||
"""
|
||||
Generate a new random channel key.
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted channel key (e.g., "ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456")
|
||||
|
||||
|
||||
Example:
|
||||
>>> key = generate_channel_key()
|
||||
>>> len(key)
|
||||
@@ -64,7 +63,7 @@ def generate_channel_key() -> str:
|
||||
# Generate 32 random alphanumeric characters
|
||||
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
|
||||
raw_key = ''.join(secrets.choice(alphabet) for _ in range(CHANNEL_KEY_LENGTH))
|
||||
|
||||
|
||||
formatted = format_channel_key(raw_key)
|
||||
debug.print(f"Generated channel key: {get_channel_fingerprint(formatted)}")
|
||||
return formatted
|
||||
@@ -73,32 +72,32 @@ def generate_channel_key() -> str:
|
||||
def format_channel_key(raw_key: str) -> str:
|
||||
"""
|
||||
Format a raw key string into the standard format.
|
||||
|
||||
|
||||
Args:
|
||||
raw_key: Raw key string (with or without dashes)
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted key with dashes (XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX)
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If key is invalid length or contains invalid characters
|
||||
|
||||
|
||||
Example:
|
||||
>>> format_channel_key("ABCD1234EFGH5678IJKL9012MNOP3456")
|
||||
"ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456"
|
||||
"""
|
||||
# Remove any existing dashes, spaces, and convert to uppercase
|
||||
clean = raw_key.replace('-', '').replace(' ', '').upper()
|
||||
|
||||
|
||||
if len(clean) != CHANNEL_KEY_LENGTH:
|
||||
raise ValueError(
|
||||
f"Channel key must be {CHANNEL_KEY_LENGTH} characters (got {len(clean)})"
|
||||
)
|
||||
|
||||
|
||||
# Validate characters
|
||||
if not all(c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' for c in clean):
|
||||
raise ValueError("Channel key must contain only letters A-Z and digits 0-9")
|
||||
|
||||
|
||||
# Format with dashes every 4 characters
|
||||
return '-'.join(clean[i:i+4] for i in range(0, CHANNEL_KEY_LENGTH, 4))
|
||||
|
||||
@@ -106,13 +105,13 @@ def format_channel_key(raw_key: str) -> str:
|
||||
def validate_channel_key(key: str) -> bool:
|
||||
"""
|
||||
Validate a channel key format.
|
||||
|
||||
|
||||
Args:
|
||||
key: Channel key to validate
|
||||
|
||||
|
||||
Returns:
|
||||
True if valid format, False otherwise
|
||||
|
||||
|
||||
Example:
|
||||
>>> validate_channel_key("ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456")
|
||||
True
|
||||
@@ -121,7 +120,7 @@ def validate_channel_key(key: str) -> bool:
|
||||
"""
|
||||
if not key:
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
formatted = format_channel_key(key)
|
||||
return bool(CHANNEL_KEY_PATTERN.match(formatted))
|
||||
@@ -129,18 +128,18 @@ def validate_channel_key(key: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_channel_key() -> Optional[str]:
|
||||
def get_channel_key() -> str | None:
|
||||
"""
|
||||
Get the current channel key from environment or config.
|
||||
|
||||
|
||||
Checks in order:
|
||||
1. STEGASOO_CHANNEL_KEY environment variable
|
||||
2. ./config/channel.key file
|
||||
3. ~/.stegasoo/channel.key file
|
||||
|
||||
|
||||
Returns:
|
||||
Channel key if configured, None if in public mode
|
||||
|
||||
|
||||
Example:
|
||||
>>> key = get_channel_key()
|
||||
>>> if key:
|
||||
@@ -156,7 +155,7 @@ def get_channel_key() -> Optional[str]:
|
||||
return format_channel_key(env_key)
|
||||
else:
|
||||
debug.print(f"Warning: Invalid {CHANNEL_KEY_ENV_VAR} format, ignoring")
|
||||
|
||||
|
||||
# 2. Check config files
|
||||
for config_path in CONFIG_LOCATIONS:
|
||||
if config_path.exists():
|
||||
@@ -165,10 +164,10 @@ def get_channel_key() -> Optional[str]:
|
||||
if key and validate_channel_key(key):
|
||||
debug.print(f"Channel key from {config_path}: {get_channel_fingerprint(key)}")
|
||||
return format_channel_key(key)
|
||||
except (IOError, PermissionError) as e:
|
||||
except (OSError, PermissionError) as e:
|
||||
debug.print(f"Could not read {config_path}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 3. No channel key configured (public mode)
|
||||
debug.print("No channel key configured (public mode)")
|
||||
return None
|
||||
@@ -177,92 +176,92 @@ def get_channel_key() -> Optional[str]:
|
||||
def set_channel_key(key: str, location: str = 'project') -> Path:
|
||||
"""
|
||||
Save a channel key to config file.
|
||||
|
||||
|
||||
Args:
|
||||
key: Channel key to save (will be formatted)
|
||||
location: 'project' for ./config/ or 'user' for ~/.stegasoo/
|
||||
|
||||
|
||||
Returns:
|
||||
Path where key was saved
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If key format is invalid
|
||||
|
||||
|
||||
Example:
|
||||
>>> path = set_channel_key("ABCD1234EFGH5678IJKL9012MNOP3456")
|
||||
>>> print(path)
|
||||
./config/channel.key
|
||||
"""
|
||||
formatted = format_channel_key(key)
|
||||
|
||||
|
||||
if location == 'user':
|
||||
config_path = Path.home() / '.stegasoo' / 'channel.key'
|
||||
else:
|
||||
config_path = Path('./config/channel.key')
|
||||
|
||||
|
||||
# Create directory if needed
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Write key with newline
|
||||
config_path.write_text(formatted + '\n')
|
||||
|
||||
|
||||
# Set restrictive permissions (owner read/write only)
|
||||
try:
|
||||
config_path.chmod(0o600)
|
||||
except (OSError, AttributeError):
|
||||
pass # Windows doesn't support chmod the same way
|
||||
|
||||
|
||||
debug.print(f"Channel key saved to {config_path}")
|
||||
return config_path
|
||||
|
||||
|
||||
def clear_channel_key(location: str = 'all') -> List[Path]:
|
||||
def clear_channel_key(location: str = 'all') -> list[Path]:
|
||||
"""
|
||||
Remove channel key configuration.
|
||||
|
||||
|
||||
Args:
|
||||
location: 'project', 'user', or 'all'
|
||||
|
||||
|
||||
Returns:
|
||||
List of paths that were deleted
|
||||
|
||||
|
||||
Example:
|
||||
>>> deleted = clear_channel_key('all')
|
||||
>>> print(f"Removed {len(deleted)} files")
|
||||
"""
|
||||
deleted = []
|
||||
|
||||
|
||||
paths_to_check = []
|
||||
if location in ('project', 'all'):
|
||||
paths_to_check.append(Path('./config/channel.key'))
|
||||
if location in ('user', 'all'):
|
||||
paths_to_check.append(Path.home() / '.stegasoo' / 'channel.key')
|
||||
|
||||
|
||||
for path in paths_to_check:
|
||||
if path.exists():
|
||||
try:
|
||||
path.unlink()
|
||||
deleted.append(path)
|
||||
debug.print(f"Removed channel key: {path}")
|
||||
except (IOError, PermissionError) as e:
|
||||
except (OSError, PermissionError) as e:
|
||||
debug.print(f"Could not remove {path}: {e}")
|
||||
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def get_channel_key_hash(key: Optional[str] = None) -> Optional[bytes]:
|
||||
def get_channel_key_hash(key: str | None = None) -> bytes | None:
|
||||
"""
|
||||
Get the channel key as a 32-byte hash suitable for key derivation.
|
||||
|
||||
|
||||
This hash is mixed into the Argon2 key derivation to bind
|
||||
encryption to a specific channel.
|
||||
|
||||
|
||||
Args:
|
||||
key: Channel key (if None, reads from config)
|
||||
|
||||
|
||||
Returns:
|
||||
32-byte SHA-256 hash of channel key, or None if no channel key
|
||||
|
||||
|
||||
Example:
|
||||
>>> hash_bytes = get_channel_key_hash()
|
||||
>>> if hash_bytes:
|
||||
@@ -270,39 +269,39 @@ def get_channel_key_hash(key: Optional[str] = None) -> Optional[bytes]:
|
||||
"""
|
||||
if key is None:
|
||||
key = get_channel_key()
|
||||
|
||||
|
||||
if not key:
|
||||
return None
|
||||
|
||||
|
||||
# Hash the formatted key to get consistent 32 bytes
|
||||
formatted = format_channel_key(key)
|
||||
return hashlib.sha256(formatted.encode('utf-8')).digest()
|
||||
|
||||
|
||||
def get_channel_fingerprint(key: Optional[str] = None) -> Optional[str]:
|
||||
def get_channel_fingerprint(key: str | None = None) -> str | None:
|
||||
"""
|
||||
Get a short fingerprint for display purposes.
|
||||
Shows first and last 4 chars with masked middle.
|
||||
|
||||
|
||||
Args:
|
||||
key: Channel key (if None, reads from config)
|
||||
|
||||
|
||||
Returns:
|
||||
Fingerprint like "ABCD-••••-••••-••••-••••-••••-••••-3456" or None
|
||||
|
||||
|
||||
Example:
|
||||
>>> print(get_channel_fingerprint())
|
||||
ABCD-••••-••••-••••-••••-••••-••••-3456
|
||||
"""
|
||||
if key is None:
|
||||
key = get_channel_key()
|
||||
|
||||
|
||||
if not key:
|
||||
return None
|
||||
|
||||
|
||||
formatted = format_channel_key(key)
|
||||
parts = formatted.split('-')
|
||||
|
||||
|
||||
# Show first and last group, mask the rest
|
||||
masked = [parts[0]] + ['••••'] * 6 + [parts[-1]]
|
||||
return '-'.join(masked)
|
||||
@@ -311,7 +310,7 @@ def get_channel_fingerprint(key: Optional[str] = None) -> Optional[str]:
|
||||
def get_channel_status() -> dict:
|
||||
"""
|
||||
Get comprehensive channel key status.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- mode: 'private' or 'public'
|
||||
@@ -319,14 +318,14 @@ def get_channel_status() -> dict:
|
||||
- fingerprint: masked key or None
|
||||
- source: where key came from or None
|
||||
- key: full key (for export) or None
|
||||
|
||||
|
||||
Example:
|
||||
>>> status = get_channel_status()
|
||||
>>> print(f"Mode: {status['mode']}")
|
||||
Mode: private
|
||||
"""
|
||||
key = get_channel_key()
|
||||
|
||||
|
||||
if key:
|
||||
# Find which source provided the key
|
||||
source = 'unknown'
|
||||
@@ -341,9 +340,9 @@ def get_channel_status() -> dict:
|
||||
if file_key and format_channel_key(file_key) == key:
|
||||
source = str(config_path)
|
||||
break
|
||||
except (IOError, PermissionError):
|
||||
except (OSError, PermissionError):
|
||||
continue
|
||||
|
||||
|
||||
return {
|
||||
'mode': 'private',
|
||||
'configured': True,
|
||||
@@ -364,10 +363,10 @@ def get_channel_status() -> dict:
|
||||
def has_channel_key() -> bool:
|
||||
"""
|
||||
Quick check if a channel key is configured.
|
||||
|
||||
|
||||
Returns:
|
||||
True if channel key is set, False for public mode
|
||||
|
||||
|
||||
Example:
|
||||
>>> if has_channel_key():
|
||||
... print("Private channel active")
|
||||
@@ -381,7 +380,7 @@ def has_channel_key() -> bool:
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
|
||||
|
||||
def print_status():
|
||||
"""Print current channel status."""
|
||||
status = get_channel_status()
|
||||
@@ -391,7 +390,7 @@ if __name__ == '__main__':
|
||||
print(f"Source: {status['source']}")
|
||||
else:
|
||||
print("No channel key configured (public mode)")
|
||||
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print("Channel Key Manager")
|
||||
print("=" * 40)
|
||||
@@ -404,24 +403,24 @@ if __name__ == '__main__':
|
||||
print(" python -m stegasoo.channel clear - Remove channel key")
|
||||
print(" python -m stegasoo.channel status - Show status")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
cmd = sys.argv[1].lower()
|
||||
|
||||
|
||||
if cmd == 'generate':
|
||||
key = generate_channel_key()
|
||||
print(f"Generated channel key:")
|
||||
print("Generated channel key:")
|
||||
print(f" {key}")
|
||||
print()
|
||||
save = input("Save to config? [y/N]: ").strip().lower()
|
||||
if save == 'y':
|
||||
path = set_channel_key(key)
|
||||
print(f"Saved to: {path}")
|
||||
|
||||
|
||||
elif cmd == 'set':
|
||||
if len(sys.argv) < 3:
|
||||
print("Usage: python -m stegasoo.channel set <KEY>")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
try:
|
||||
key = sys.argv[2]
|
||||
formatted = format_channel_key(key)
|
||||
@@ -431,7 +430,7 @@ if __name__ == '__main__':
|
||||
except ValueError as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
elif cmd == 'show':
|
||||
status = get_channel_status()
|
||||
if status['configured']:
|
||||
@@ -439,17 +438,17 @@ if __name__ == '__main__':
|
||||
print(f"Source: {status['source']}")
|
||||
else:
|
||||
print("No channel key configured")
|
||||
|
||||
|
||||
elif cmd == 'clear':
|
||||
deleted = clear_channel_key('all')
|
||||
if deleted:
|
||||
print(f"Removed channel key from: {', '.join(str(p) for p in deleted)}")
|
||||
else:
|
||||
print("No channel key files found")
|
||||
|
||||
|
||||
elif cmd == 'status':
|
||||
print_status()
|
||||
|
||||
|
||||
else:
|
||||
print(f"Unknown command: {cmd}")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -8,33 +8,29 @@ Changes in v3.2.0:
|
||||
- Updated help text to use 'passphrase' terminology
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
from .constants import (
|
||||
__version__,
|
||||
MAX_MESSAGE_SIZE,
|
||||
MAX_FILE_PAYLOAD_SIZE,
|
||||
DEFAULT_PIN_LENGTH,
|
||||
DEFAULT_PASSPHRASE_WORDS, # v3.2.0: renamed from DEFAULT_PHRASE_WORDS
|
||||
)
|
||||
from .compression import (
|
||||
CompressionAlgorithm,
|
||||
get_available_algorithms,
|
||||
algorithm_name,
|
||||
HAS_LZ4,
|
||||
)
|
||||
from .batch import (
|
||||
BatchProcessor,
|
||||
BatchResult,
|
||||
batch_capacity_check,
|
||||
print_batch_result,
|
||||
)
|
||||
|
||||
from .compression import (
|
||||
HAS_LZ4,
|
||||
CompressionAlgorithm,
|
||||
algorithm_name,
|
||||
get_available_algorithms,
|
||||
)
|
||||
from .constants import (
|
||||
DEFAULT_PASSPHRASE_WORDS, # v3.2.0: renamed from DEFAULT_PHRASE_WORDS
|
||||
DEFAULT_PIN_LENGTH,
|
||||
MAX_FILE_PAYLOAD_SIZE,
|
||||
MAX_MESSAGE_SIZE,
|
||||
__version__,
|
||||
)
|
||||
|
||||
# Click context settings
|
||||
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
|
||||
@@ -47,7 +43,7 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
|
||||
def cli(ctx, json_output):
|
||||
"""
|
||||
Stegasoo - Steganography with hybrid authentication.
|
||||
|
||||
|
||||
Hide messages in images using PIN + passphrase security.
|
||||
"""
|
||||
ctx.ensure_object(dict)
|
||||
@@ -61,35 +57,35 @@ def cli(ctx, json_output):
|
||||
@cli.command()
|
||||
@click.argument('image', type=click.Path(exists=True))
|
||||
@click.option('-m', '--message', help='Message to encode')
|
||||
@click.option('-f', '--file', 'file_payload', type=click.Path(exists=True),
|
||||
@click.option('-f', '--file', 'file_payload', type=click.Path(exists=True),
|
||||
help='File to embed instead of message')
|
||||
@click.option('-o', '--output', type=click.Path(), help='Output image path')
|
||||
@click.option('--passphrase', prompt=True, hide_input=True,
|
||||
@click.option('--passphrase', prompt=True, hide_input=True,
|
||||
confirmation_prompt=True, help='Passphrase (recommend 4+ words)')
|
||||
@click.option('--pin', prompt=True, hide_input=True,
|
||||
confirmation_prompt=True, help='PIN code')
|
||||
@click.option('--compress/--no-compress', default=True,
|
||||
@click.option('--compress/--no-compress', default=True,
|
||||
help='Enable/disable compression (default: enabled)')
|
||||
@click.option('--algorithm', type=click.Choice(['zlib', 'lz4', 'none']),
|
||||
@click.option('--algorithm', type=click.Choice(['zlib', 'lz4', 'none']),
|
||||
default='zlib', help='Compression algorithm')
|
||||
@click.option('--dry-run', is_flag=True, help='Show capacity usage without encoding')
|
||||
@click.pass_context
|
||||
def encode(ctx, image, message, file_payload, output, passphrase, pin,
|
||||
def encode(ctx, image, message, file_payload, output, passphrase, pin,
|
||||
compress, algorithm, dry_run):
|
||||
"""
|
||||
Encode a message or file into an image.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
stegasoo encode photo.png -m "Secret message" --passphrase --pin
|
||||
|
||||
|
||||
stegasoo encode photo.png -f secret.pdf -o encoded.png
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if not message and not file_payload:
|
||||
raise click.UsageError("Either --message or --file is required")
|
||||
|
||||
|
||||
# Parse compression algorithm
|
||||
algo_map = {
|
||||
'zlib': CompressionAlgorithm.ZLIB,
|
||||
@@ -97,11 +93,11 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
|
||||
'none': CompressionAlgorithm.NONE,
|
||||
}
|
||||
compression_algo = algo_map[algorithm] if compress else CompressionAlgorithm.NONE
|
||||
|
||||
|
||||
if algorithm == 'lz4' and not HAS_LZ4:
|
||||
click.echo("Warning: LZ4 not available, falling back to zlib", err=True)
|
||||
compression_algo = CompressionAlgorithm.ZLIB
|
||||
|
||||
|
||||
# Calculate payload size
|
||||
if file_payload:
|
||||
payload_size = Path(file_payload).stat().st_size
|
||||
@@ -109,12 +105,12 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
|
||||
else:
|
||||
payload_size = len(message.encode('utf-8'))
|
||||
payload_type = "text"
|
||||
|
||||
|
||||
# Get image capacity
|
||||
with Image.open(image) as img:
|
||||
width, height = img.size
|
||||
capacity_bytes = (width * height * 3 // 8) - 69 # v3.2.0: corrected overhead
|
||||
|
||||
|
||||
if dry_run:
|
||||
result = {
|
||||
"image": image,
|
||||
@@ -126,7 +122,7 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
|
||||
"usage_percent": round(payload_size / capacity_bytes * 100, 1),
|
||||
"fits": payload_size < capacity_bytes,
|
||||
}
|
||||
|
||||
|
||||
if ctx.obj.get('json'):
|
||||
click.echo(json.dumps(result, indent=2))
|
||||
else:
|
||||
@@ -137,11 +133,11 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
|
||||
click.echo(f"Usage: {result['usage_percent']}%")
|
||||
click.echo(f"Status: {'✓ Fits' if result['fits'] else '✗ Too large'}")
|
||||
return
|
||||
|
||||
|
||||
# Actual encoding would happen here
|
||||
# For now, show what would be done
|
||||
output = output or f"{Path(image).stem}_encoded.png"
|
||||
|
||||
|
||||
if ctx.obj.get('json'):
|
||||
click.echo(json.dumps({
|
||||
"status": "success",
|
||||
@@ -159,17 +155,17 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
|
||||
@click.argument('image', type=click.Path(exists=True))
|
||||
@click.option('--passphrase', prompt=True, hide_input=True, help='Passphrase')
|
||||
@click.option('--pin', prompt=True, hide_input=True, help='PIN code')
|
||||
@click.option('-o', '--output', type=click.Path(),
|
||||
@click.option('-o', '--output', type=click.Path(),
|
||||
help='Output path for file payloads')
|
||||
@click.pass_context
|
||||
def decode(ctx, image, passphrase, pin, output):
|
||||
"""
|
||||
Decode a message or file from an image.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
stegasoo decode encoded.png --passphrase --pin
|
||||
|
||||
|
||||
stegasoo decode encoded.png -o ./extracted/
|
||||
"""
|
||||
# Actual decoding would happen here
|
||||
@@ -179,7 +175,7 @@ def decode(ctx, image, passphrase, pin, output):
|
||||
"payload_type": "text",
|
||||
"message": "[Decoded message would appear here]",
|
||||
}
|
||||
|
||||
|
||||
if ctx.obj.get('json'):
|
||||
click.echo(json.dumps(result, indent=2))
|
||||
else:
|
||||
@@ -222,27 +218,27 @@ def batch_encode(ctx, images, message, file_payload, output_dir, suffix,
|
||||
passphrase, pin, compress, algorithm, recursive, jobs, verbose):
|
||||
"""
|
||||
Encode message into multiple images.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
stegasoo batch encode *.png -m "Secret" --passphrase --pin
|
||||
|
||||
|
||||
stegasoo batch encode ./photos/ -r -o ./encoded/
|
||||
"""
|
||||
if not message and not file_payload:
|
||||
raise click.UsageError("Either --message or --file is required")
|
||||
|
||||
|
||||
processor = BatchProcessor(max_workers=jobs)
|
||||
|
||||
|
||||
# Progress callback
|
||||
def progress(current, total, item):
|
||||
if not ctx.obj.get('json'):
|
||||
status = "✓" if item.status.value == "success" else "✗"
|
||||
click.echo(f"[{current}/{total}] {status} {item.input_path.name}")
|
||||
|
||||
|
||||
# v3.2.0: Use 'passphrase' key instead of 'phrase'
|
||||
credentials = {"passphrase": passphrase, "pin": pin}
|
||||
|
||||
|
||||
result = processor.batch_encode(
|
||||
images=list(images),
|
||||
message=message,
|
||||
@@ -254,7 +250,7 @@ def batch_encode(ctx, images, message, file_payload, output_dir, suffix,
|
||||
recursive=recursive,
|
||||
progress_callback=progress if not ctx.obj.get('json') else None,
|
||||
)
|
||||
|
||||
|
||||
if ctx.obj.get('json'):
|
||||
click.echo(result.to_json())
|
||||
else:
|
||||
@@ -275,24 +271,24 @@ def batch_encode(ctx, images, message, file_payload, output_dir, suffix,
|
||||
def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verbose):
|
||||
"""
|
||||
Decode messages from multiple images.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
stegasoo batch decode encoded*.png --passphrase --pin
|
||||
|
||||
|
||||
stegasoo batch decode ./encoded/ -r -o ./extracted/
|
||||
"""
|
||||
processor = BatchProcessor(max_workers=jobs)
|
||||
|
||||
|
||||
# Progress callback
|
||||
def progress(current, total, item):
|
||||
if not ctx.obj.get('json'):
|
||||
status = "✓" if item.status.value == "success" else "✗"
|
||||
click.echo(f"[{current}/{total}] {status} {item.input_path.name}")
|
||||
|
||||
|
||||
# v3.2.0: Use 'passphrase' key instead of 'phrase'
|
||||
credentials = {"passphrase": passphrase, "pin": pin}
|
||||
|
||||
|
||||
result = processor.batch_decode(
|
||||
images=list(images),
|
||||
output_dir=Path(output_dir) if output_dir else None,
|
||||
@@ -300,7 +296,7 @@ def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verb
|
||||
recursive=recursive,
|
||||
progress_callback=progress if not ctx.obj.get('json') else None,
|
||||
)
|
||||
|
||||
|
||||
if ctx.obj.get('json'):
|
||||
click.echo(result.to_json())
|
||||
else:
|
||||
@@ -315,21 +311,21 @@ def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verb
|
||||
def batch_check(ctx, images, recursive):
|
||||
"""
|
||||
Check capacity of multiple images.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
stegasoo batch check *.png
|
||||
|
||||
|
||||
stegasoo batch check ./photos/ -r
|
||||
"""
|
||||
results = batch_capacity_check(list(images), recursive)
|
||||
|
||||
|
||||
if ctx.obj.get('json'):
|
||||
click.echo(json.dumps(results, indent=2))
|
||||
else:
|
||||
click.echo(f"{'Image':<40} {'Size':<12} {'Capacity':<12} {'Status'}")
|
||||
click.echo("─" * 80)
|
||||
|
||||
|
||||
for item in results:
|
||||
if 'error' in item:
|
||||
click.echo(f"{Path(item['path']).name:<40} {'ERROR':<12} {'':<12} {item['error']}")
|
||||
@@ -337,10 +333,10 @@ def batch_check(ctx, images, recursive):
|
||||
name = Path(item['path']).name
|
||||
if len(name) > 38:
|
||||
name = name[:35] + "..."
|
||||
|
||||
|
||||
status = "✓" if item['valid'] else "⚠"
|
||||
warnings = ", ".join(item.get('warnings', []))
|
||||
|
||||
|
||||
click.echo(
|
||||
f"{name:<40} "
|
||||
f"{item['dimensions']:<12} "
|
||||
@@ -354,7 +350,7 @@ def batch_check(ctx, images, recursive):
|
||||
# =============================================================================
|
||||
|
||||
@cli.command()
|
||||
@click.option('--words', default=DEFAULT_PASSPHRASE_WORDS,
|
||||
@click.option('--words', default=DEFAULT_PASSPHRASE_WORDS,
|
||||
help=f'Number of words in passphrase (default: {DEFAULT_PASSPHRASE_WORDS})')
|
||||
@click.option('--pin-length', default=DEFAULT_PIN_LENGTH,
|
||||
help=f'PIN length (default: {DEFAULT_PIN_LENGTH})')
|
||||
@@ -362,21 +358,21 @@ def batch_check(ctx, images, recursive):
|
||||
def generate(ctx, words, pin_length):
|
||||
"""
|
||||
Generate random credentials (passphrase + PIN).
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
stegasoo generate
|
||||
|
||||
|
||||
stegasoo generate --words 6 --pin-length 8
|
||||
"""
|
||||
import secrets
|
||||
|
||||
|
||||
# Generate PIN
|
||||
pin = ''.join(str(secrets.randbelow(10)) for _ in range(pin_length))
|
||||
# Ensure PIN doesn't start with 0
|
||||
if pin[0] == '0':
|
||||
pin = str(secrets.randbelow(9) + 1) + pin[1:]
|
||||
|
||||
|
||||
# Generate passphrase (would use BIP-39 wordlist)
|
||||
# Placeholder - actual implementation uses constants.get_wordlist()
|
||||
try:
|
||||
@@ -388,16 +384,16 @@ def generate(ctx, words, pin_length):
|
||||
sample_words = ['alpha', 'bravo', 'charlie', 'delta', 'echo', 'foxtrot',
|
||||
'golf', 'hotel', 'india', 'juliet', 'kilo', 'lima']
|
||||
phrase_words = [secrets.choice(sample_words) for _ in range(words)]
|
||||
|
||||
|
||||
passphrase = ' '.join(phrase_words)
|
||||
|
||||
|
||||
result = {
|
||||
"passphrase": passphrase,
|
||||
"pin": pin,
|
||||
"passphrase_words": words,
|
||||
"pin_length": pin_length,
|
||||
}
|
||||
|
||||
|
||||
if ctx.obj.get('json'):
|
||||
click.echo(json.dumps(result, indent=2))
|
||||
else:
|
||||
@@ -421,17 +417,17 @@ def info(ctx):
|
||||
"max_file_payload_bytes": MAX_FILE_PAYLOAD_SIZE,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if ctx.obj.get('json'):
|
||||
click.echo(json.dumps(info_data, indent=2))
|
||||
else:
|
||||
click.echo(f"Stegasoo v{__version__}")
|
||||
click.echo(f"\nCompression algorithms:")
|
||||
click.echo("\nCompression algorithms:")
|
||||
for algo in get_available_algorithms():
|
||||
click.echo(f" • {algorithm_name(algo)}")
|
||||
if not HAS_LZ4:
|
||||
click.echo(" (install 'lz4' for LZ4 support)")
|
||||
click.echo(f"\nLimits:")
|
||||
click.echo("\nLimits:")
|
||||
click.echo(f" • Max message: {MAX_MESSAGE_SIZE:,} bytes")
|
||||
click.echo(f" • Max file payload: {MAX_FILE_PAYLOAD_SIZE:,} bytes")
|
||||
|
||||
|
||||
@@ -5,10 +5,9 @@ Provides transparent compression/decompression for payloads before encryption.
|
||||
Supports multiple algorithms with automatic detection on decompression.
|
||||
"""
|
||||
|
||||
import zlib
|
||||
import struct
|
||||
import zlib
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
|
||||
# Optional LZ4 support (faster, slightly worse ratio)
|
||||
try:
|
||||
@@ -43,26 +42,26 @@ class CompressionError(Exception):
|
||||
def compress(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm.ZLIB) -> bytes:
|
||||
"""
|
||||
Compress data with specified algorithm.
|
||||
|
||||
|
||||
Format: MAGIC (4) + ALGORITHM (1) + ORIGINAL_SIZE (4) + COMPRESSED_DATA
|
||||
|
||||
|
||||
Args:
|
||||
data: Raw bytes to compress
|
||||
algorithm: Compression algorithm to use
|
||||
|
||||
|
||||
Returns:
|
||||
Compressed data with header, or original data if compression didn't help
|
||||
"""
|
||||
if len(data) < MIN_COMPRESS_SIZE:
|
||||
# Too small to benefit from compression
|
||||
return _wrap_uncompressed(data)
|
||||
|
||||
|
||||
if algorithm == CompressionAlgorithm.NONE:
|
||||
return _wrap_uncompressed(data)
|
||||
|
||||
|
||||
elif algorithm == CompressionAlgorithm.ZLIB:
|
||||
compressed = zlib.compress(data, level=ZLIB_LEVEL)
|
||||
|
||||
|
||||
elif algorithm == CompressionAlgorithm.LZ4:
|
||||
if not HAS_LZ4:
|
||||
# Fall back to zlib if LZ4 not available
|
||||
@@ -72,11 +71,11 @@ def compress(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm
|
||||
compressed = lz4.frame.compress(data)
|
||||
else:
|
||||
raise CompressionError(f"Unknown compression algorithm: {algorithm}")
|
||||
|
||||
|
||||
# Only use compression if it actually reduced size
|
||||
if len(compressed) >= len(data):
|
||||
return _wrap_uncompressed(data)
|
||||
|
||||
|
||||
# Build header: MAGIC + algorithm + original_size + compressed_data
|
||||
header = COMPRESSION_MAGIC + struct.pack('<BI', algorithm, len(data))
|
||||
return header + compressed
|
||||
@@ -85,10 +84,10 @@ def compress(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm
|
||||
def decompress(data: bytes) -> bytes:
|
||||
"""
|
||||
Decompress data, auto-detecting algorithm from header.
|
||||
|
||||
|
||||
Args:
|
||||
data: Potentially compressed data
|
||||
|
||||
|
||||
Returns:
|
||||
Decompressed data (or original if not compressed)
|
||||
"""
|
||||
@@ -96,24 +95,24 @@ def decompress(data: bytes) -> bytes:
|
||||
if not data.startswith(COMPRESSION_MAGIC):
|
||||
# Not compressed by us, return as-is
|
||||
return data
|
||||
|
||||
|
||||
if len(data) < 9: # MAGIC(4) + ALGO(1) + SIZE(4)
|
||||
raise CompressionError("Truncated compression header")
|
||||
|
||||
|
||||
# Parse header
|
||||
algorithm = CompressionAlgorithm(data[4])
|
||||
original_size = struct.unpack('<I', data[5:9])[0]
|
||||
compressed_data = data[9:]
|
||||
|
||||
|
||||
if algorithm == CompressionAlgorithm.NONE:
|
||||
result = compressed_data
|
||||
|
||||
|
||||
elif algorithm == CompressionAlgorithm.ZLIB:
|
||||
try:
|
||||
result = zlib.decompress(compressed_data)
|
||||
except zlib.error as e:
|
||||
raise CompressionError(f"Zlib decompression failed: {e}")
|
||||
|
||||
|
||||
elif algorithm == CompressionAlgorithm.LZ4:
|
||||
if not HAS_LZ4:
|
||||
raise CompressionError("LZ4 compression used but lz4 package not installed")
|
||||
@@ -123,13 +122,13 @@ def decompress(data: bytes) -> bytes:
|
||||
raise CompressionError(f"LZ4 decompression failed: {e}")
|
||||
else:
|
||||
raise CompressionError(f"Unknown compression algorithm: {algorithm}")
|
||||
|
||||
|
||||
# Verify size
|
||||
if len(result) != original_size:
|
||||
raise CompressionError(
|
||||
f"Size mismatch: expected {original_size}, got {len(result)}"
|
||||
)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -142,7 +141,7 @@ def _wrap_uncompressed(data: bytes) -> bytes:
|
||||
def get_compression_ratio(original: bytes, compressed: bytes) -> float:
|
||||
"""
|
||||
Calculate compression ratio.
|
||||
|
||||
|
||||
Returns:
|
||||
Ratio where < 1.0 means compression helped, > 1.0 means it expanded
|
||||
"""
|
||||
@@ -155,36 +154,36 @@ def estimate_compressed_size(data: bytes, algorithm: CompressionAlgorithm = Comp
|
||||
"""
|
||||
Estimate compressed size without full compression.
|
||||
Uses sampling for large data.
|
||||
|
||||
|
||||
Args:
|
||||
data: Data to estimate
|
||||
algorithm: Algorithm to estimate for
|
||||
|
||||
|
||||
Returns:
|
||||
Estimated compressed size in bytes
|
||||
"""
|
||||
if len(data) < MIN_COMPRESS_SIZE:
|
||||
return len(data) + 9 # Header overhead
|
||||
|
||||
|
||||
# For small data, just compress it
|
||||
if len(data) < 10000:
|
||||
compressed = compress(data, algorithm)
|
||||
return len(compressed)
|
||||
|
||||
|
||||
# For large data, sample and extrapolate
|
||||
sample_size = 8192
|
||||
sample = data[:sample_size]
|
||||
|
||||
|
||||
if algorithm == CompressionAlgorithm.ZLIB:
|
||||
compressed_sample = zlib.compress(sample, level=ZLIB_LEVEL)
|
||||
elif algorithm == CompressionAlgorithm.LZ4 and HAS_LZ4:
|
||||
compressed_sample = lz4.frame.compress(sample)
|
||||
else:
|
||||
compressed_sample = zlib.compress(sample, level=ZLIB_LEVEL)
|
||||
|
||||
|
||||
ratio = len(compressed_sample) / len(sample)
|
||||
estimated = int(len(data) * ratio) + 9 # Add header
|
||||
|
||||
|
||||
return estimated
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ BREAKING CHANGES in v3.2.0:
|
||||
- Renamed day_phrase → passphrase throughout codebase
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# ============================================================================
|
||||
@@ -89,7 +88,7 @@ RECOMMENDED_PASSPHRASE_WORDS = 4 # Best practice guideline
|
||||
|
||||
# Legacy aliases for backward compatibility during transition
|
||||
MIN_PHRASE_WORDS = MIN_PASSPHRASE_WORDS
|
||||
MAX_PHRASE_WORDS = MAX_PASSPHRASE_WORDS
|
||||
MAX_PHRASE_WORDS = MAX_PASSPHRASE_WORDS
|
||||
DEFAULT_PHRASE_WORDS = DEFAULT_PASSPHRASE_WORDS
|
||||
|
||||
# RSA configuration
|
||||
@@ -180,11 +179,11 @@ def get_data_dir() -> Path:
|
||||
Path.cwd().parent / 'data', # One level up from cwd
|
||||
Path.cwd().parent.parent / 'data', # Two levels up from cwd
|
||||
]
|
||||
|
||||
|
||||
for path in candidates:
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
|
||||
# Default to first candidate
|
||||
return candidates[0]
|
||||
|
||||
@@ -192,14 +191,14 @@ def get_data_dir() -> Path:
|
||||
def get_bip39_words() -> list[str]:
|
||||
"""Load BIP-39 wordlist."""
|
||||
wordlist_path = get_data_dir() / 'bip39-words.txt'
|
||||
|
||||
|
||||
if not wordlist_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"BIP-39 wordlist not found at {wordlist_path}. "
|
||||
"Please ensure bip39-words.txt is in the data directory."
|
||||
)
|
||||
|
||||
with open(wordlist_path, 'r') as f:
|
||||
|
||||
with open(wordlist_path) as f:
|
||||
return [line.strip() for line in f if line.strip()]
|
||||
|
||||
|
||||
@@ -240,18 +239,18 @@ DCT_BYTES_PER_PIXEL = 0.125 # Approximate for DCT mode (varies by implementatio
|
||||
def detect_stego_mode(encrypted_data: bytes) -> str:
|
||||
"""
|
||||
Detect embedding mode from encrypted payload header.
|
||||
|
||||
|
||||
Args:
|
||||
encrypted_data: First few bytes of extracted payload
|
||||
|
||||
|
||||
Returns:
|
||||
'lsb' or 'dct' or 'unknown'
|
||||
"""
|
||||
if len(encrypted_data) < 4:
|
||||
return 'unknown'
|
||||
|
||||
|
||||
header = encrypted_data[:4]
|
||||
|
||||
|
||||
if header == b'\x89ST3':
|
||||
return EMBED_MODE_LSB
|
||||
elif header == b'\x89DCT':
|
||||
|
||||
@@ -15,38 +15,40 @@ BREAKING CHANGES in v3.2.0:
|
||||
- Renamed day_phrase → passphrase (no daily rotation needed)
|
||||
"""
|
||||
|
||||
import io
|
||||
import hashlib
|
||||
import io
|
||||
import secrets
|
||||
import struct
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from PIL import Image
|
||||
|
||||
from .constants import (
|
||||
MAGIC_HEADER, FORMAT_VERSION,
|
||||
SALT_SIZE, IV_SIZE, TAG_SIZE,
|
||||
ARGON2_TIME_COST, ARGON2_MEMORY_COST, ARGON2_PARALLELISM,
|
||||
PBKDF2_ITERATIONS,
|
||||
PAYLOAD_TEXT, PAYLOAD_FILE,
|
||||
ARGON2_MEMORY_COST,
|
||||
ARGON2_PARALLELISM,
|
||||
ARGON2_TIME_COST,
|
||||
FORMAT_VERSION,
|
||||
IV_SIZE,
|
||||
MAGIC_HEADER,
|
||||
MAX_FILENAME_LENGTH,
|
||||
PAYLOAD_FILE,
|
||||
PAYLOAD_TEXT,
|
||||
PBKDF2_ITERATIONS,
|
||||
SALT_SIZE,
|
||||
TAG_SIZE,
|
||||
)
|
||||
from .models import FilePayload, DecodeResult
|
||||
from .exceptions import (
|
||||
EncryptionError, DecryptionError, KeyDerivationError, InvalidHeaderError
|
||||
)
|
||||
from .exceptions import DecryptionError, EncryptionError, InvalidHeaderError, KeyDerivationError
|
||||
from .models import DecodeResult, FilePayload
|
||||
|
||||
# Check for Argon2 availability
|
||||
try:
|
||||
from argon2.low_level import hash_secret_raw, Type
|
||||
from argon2.low_level import Type, hash_secret_raw
|
||||
HAS_ARGON2 = True
|
||||
except ImportError:
|
||||
HAS_ARGON2 = False
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -57,28 +59,28 @@ except ImportError:
|
||||
CHANNEL_KEY_AUTO = "auto"
|
||||
|
||||
|
||||
def _resolve_channel_key(channel_key: Optional[Union[str, bool]]) -> Optional[bytes]:
|
||||
def _resolve_channel_key(channel_key: str | bool | None) -> bytes | None:
|
||||
"""
|
||||
Resolve channel key parameter to actual key hash.
|
||||
|
||||
|
||||
Args:
|
||||
channel_key: Channel key parameter with these behaviors:
|
||||
- None or "auto": Use server's configured key (from env/config)
|
||||
- str (valid key): Use this specific key
|
||||
- "" or False: Explicitly use NO channel key (public mode)
|
||||
|
||||
|
||||
Returns:
|
||||
32-byte channel key hash, or None for public mode
|
||||
"""
|
||||
# Explicit public mode
|
||||
if channel_key == "" or channel_key is False:
|
||||
return None
|
||||
|
||||
|
||||
# Auto-detect from environment/config
|
||||
if channel_key is None or channel_key == CHANNEL_KEY_AUTO:
|
||||
from .channel import get_channel_key_hash
|
||||
return get_channel_key_hash()
|
||||
|
||||
|
||||
# Explicit key provided - validate and hash it
|
||||
if isinstance(channel_key, str):
|
||||
from .channel import format_channel_key, validate_channel_key
|
||||
@@ -86,7 +88,7 @@ def _resolve_channel_key(channel_key: Optional[Union[str, bool]]) -> Optional[by
|
||||
raise ValueError(f"Invalid channel key format: {channel_key}")
|
||||
formatted = format_channel_key(channel_key)
|
||||
return hashlib.sha256(formatted.encode('utf-8')).digest()
|
||||
|
||||
|
||||
raise ValueError(f"Invalid channel_key type: {type(channel_key)}")
|
||||
|
||||
|
||||
@@ -97,19 +99,19 @@ def _resolve_channel_key(channel_key: Optional[Union[str, bool]]) -> Optional[by
|
||||
def hash_photo(image_data: bytes) -> bytes:
|
||||
"""
|
||||
Compute deterministic hash of photo pixel content.
|
||||
|
||||
|
||||
This normalizes the image to RGB and hashes the raw pixel data,
|
||||
making it resistant to metadata changes.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Raw image file bytes
|
||||
|
||||
|
||||
Returns:
|
||||
32-byte SHA-256 hash
|
||||
"""
|
||||
img: Image.Image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
||||
pixels = img.tobytes()
|
||||
|
||||
|
||||
# Double-hash with prefix for additional mixing
|
||||
h = hashlib.sha256(pixels).digest()
|
||||
h = hashlib.sha256(h + pixels[:1024]).digest()
|
||||
@@ -121,12 +123,12 @@ def derive_hybrid_key(
|
||||
passphrase: str,
|
||||
salt: bytes,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Derive encryption key from multiple factors.
|
||||
|
||||
|
||||
Combines:
|
||||
- Photo hash (something you have)
|
||||
- Passphrase (something you know)
|
||||
@@ -134,9 +136,9 @@ def derive_hybrid_key(
|
||||
- RSA key (something you have)
|
||||
- Channel key (deployment/group binding)
|
||||
- Salt (random per message)
|
||||
|
||||
|
||||
Uses Argon2id if available, falls back to PBKDF2.
|
||||
|
||||
|
||||
Args:
|
||||
photo_data: Reference photo bytes
|
||||
passphrase: Shared passphrase (recommend 4+ words)
|
||||
@@ -147,19 +149,19 @@ def derive_hybrid_key(
|
||||
- None or "auto": Use configured key
|
||||
- str: Use this specific key
|
||||
- "" or False: No channel key (public mode)
|
||||
|
||||
|
||||
Returns:
|
||||
32-byte derived key
|
||||
|
||||
|
||||
Raises:
|
||||
KeyDerivationError: If key derivation fails
|
||||
"""
|
||||
try:
|
||||
photo_hash = hash_photo(photo_data)
|
||||
|
||||
|
||||
# Resolve channel key
|
||||
channel_hash = _resolve_channel_key(channel_key)
|
||||
|
||||
|
||||
# Build key material
|
||||
key_material = (
|
||||
photo_hash +
|
||||
@@ -167,15 +169,15 @@ def derive_hybrid_key(
|
||||
pin.encode() +
|
||||
salt
|
||||
)
|
||||
|
||||
|
||||
# Add RSA key hash if provided
|
||||
if rsa_key_data:
|
||||
key_material += hashlib.sha256(rsa_key_data).digest()
|
||||
|
||||
|
||||
# Add channel key hash if configured (v4.0.0)
|
||||
if channel_hash:
|
||||
key_material += channel_hash
|
||||
|
||||
|
||||
if HAS_ARGON2:
|
||||
key = hash_secret_raw(
|
||||
secret=key_material,
|
||||
@@ -195,9 +197,9 @@ def derive_hybrid_key(
|
||||
backend=default_backend()
|
||||
)
|
||||
key = kdf.derive(key_material)
|
||||
|
||||
|
||||
return key
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise KeyDerivationError(f"Failed to derive key: {e}") from e
|
||||
|
||||
@@ -206,61 +208,61 @@ def derive_pixel_key(
|
||||
photo_data: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Derive key for pseudo-random pixel selection.
|
||||
|
||||
|
||||
This key determines which pixels are used for embedding,
|
||||
making the message location unpredictable without the correct inputs.
|
||||
|
||||
|
||||
Args:
|
||||
photo_data: Reference photo bytes
|
||||
passphrase: Shared passphrase
|
||||
pin: Optional static PIN
|
||||
rsa_key_data: Optional RSA key bytes
|
||||
channel_key: Channel key parameter (see derive_hybrid_key)
|
||||
|
||||
|
||||
Returns:
|
||||
32-byte key for pixel selection
|
||||
"""
|
||||
photo_hash = hash_photo(photo_data)
|
||||
|
||||
|
||||
# Resolve channel key
|
||||
channel_hash = _resolve_channel_key(channel_key)
|
||||
|
||||
|
||||
material = (
|
||||
photo_hash +
|
||||
passphrase.lower().encode() +
|
||||
pin.encode()
|
||||
)
|
||||
|
||||
|
||||
if rsa_key_data:
|
||||
material += hashlib.sha256(rsa_key_data).digest()
|
||||
|
||||
|
||||
# Add channel key hash if configured (v4.0.0)
|
||||
if channel_hash:
|
||||
material += channel_hash
|
||||
|
||||
|
||||
return hashlib.sha256(material + b"pixel_selection").digest()
|
||||
|
||||
|
||||
def _pack_payload(
|
||||
content: Union[str, bytes, FilePayload],
|
||||
content: str | bytes | FilePayload,
|
||||
) -> tuple[bytes, int]:
|
||||
"""
|
||||
Pack payload with type marker and metadata.
|
||||
|
||||
|
||||
Format for text:
|
||||
[type:1][data]
|
||||
|
||||
|
||||
Format for file:
|
||||
[type:1][filename_len:2][filename][mime_len:2][mime][data]
|
||||
|
||||
|
||||
Args:
|
||||
content: Text string, raw bytes, or FilePayload
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (packed bytes, payload type)
|
||||
"""
|
||||
@@ -268,12 +270,12 @@ def _pack_payload(
|
||||
# Text message
|
||||
data = content.encode('utf-8')
|
||||
return bytes([PAYLOAD_TEXT]) + data, PAYLOAD_TEXT
|
||||
|
||||
|
||||
elif isinstance(content, FilePayload):
|
||||
# File with metadata
|
||||
filename = content.filename[:MAX_FILENAME_LENGTH].encode('utf-8')
|
||||
mime = (content.mime_type or '')[:100].encode('utf-8')
|
||||
|
||||
|
||||
packed = (
|
||||
bytes([PAYLOAD_FILE]) +
|
||||
struct.pack('>H', len(filename)) +
|
||||
@@ -283,7 +285,7 @@ def _pack_payload(
|
||||
content.data
|
||||
)
|
||||
return packed, PAYLOAD_FILE
|
||||
|
||||
|
||||
else:
|
||||
# Raw bytes - treat as file with no name
|
||||
packed = (
|
||||
@@ -298,49 +300,49 @@ def _pack_payload(
|
||||
def _unpack_payload(data: bytes) -> DecodeResult:
|
||||
"""
|
||||
Unpack payload and extract content with metadata.
|
||||
|
||||
|
||||
Args:
|
||||
data: Packed payload bytes
|
||||
|
||||
|
||||
Returns:
|
||||
DecodeResult with appropriate content
|
||||
"""
|
||||
if len(data) < 1:
|
||||
raise DecryptionError("Empty payload")
|
||||
|
||||
|
||||
payload_type = data[0]
|
||||
|
||||
|
||||
if payload_type == PAYLOAD_TEXT:
|
||||
# Text message
|
||||
text = data[1:].decode('utf-8')
|
||||
return DecodeResult(payload_type='text', message=text)
|
||||
|
||||
|
||||
elif payload_type == PAYLOAD_FILE:
|
||||
# File with metadata
|
||||
offset = 1
|
||||
|
||||
|
||||
# Read filename
|
||||
filename_len = struct.unpack('>H', data[offset:offset+2])[0]
|
||||
offset += 2
|
||||
filename = data[offset:offset+filename_len].decode('utf-8') if filename_len else None
|
||||
offset += filename_len
|
||||
|
||||
|
||||
# Read mime type
|
||||
mime_len = struct.unpack('>H', data[offset:offset+2])[0]
|
||||
offset += 2
|
||||
mime_type = data[offset:offset+mime_len].decode('utf-8') if mime_len else None
|
||||
offset += mime_len
|
||||
|
||||
|
||||
# Rest is file data
|
||||
file_data = data[offset:]
|
||||
|
||||
|
||||
return DecodeResult(
|
||||
payload_type='file',
|
||||
file_data=file_data,
|
||||
filename=filename,
|
||||
mime_type=mime_type
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
# Unknown type - try to decode as text (backward compatibility)
|
||||
try:
|
||||
@@ -359,16 +361,16 @@ FLAG_CHANNEL_KEY = 0x01 # Set if encoded with a channel key
|
||||
|
||||
|
||||
def encrypt_message(
|
||||
message: Union[str, bytes, FilePayload],
|
||||
message: str | bytes | FilePayload,
|
||||
photo_data: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Encrypt message or file using AES-256-GCM with hybrid key derivation.
|
||||
|
||||
|
||||
Message format (v4.0.0 - with channel key support):
|
||||
- Magic header (4 bytes)
|
||||
- Version (1 byte) = 5
|
||||
@@ -377,7 +379,7 @@ def encrypt_message(
|
||||
- IV (12 bytes)
|
||||
- Auth tag (16 bytes)
|
||||
- Ciphertext (variable, padded)
|
||||
|
||||
|
||||
Args:
|
||||
message: Message string, raw bytes, or FilePayload to encrypt
|
||||
photo_data: Reference photo bytes
|
||||
@@ -386,12 +388,12 @@ def encrypt_message(
|
||||
rsa_key_data: Optional RSA key bytes
|
||||
channel_key: Channel key parameter:
|
||||
- None or "auto": Use configured key
|
||||
- str: Use this specific key
|
||||
- str: Use this specific key
|
||||
- "" or False: No channel key (public mode)
|
||||
|
||||
|
||||
Returns:
|
||||
Encrypted message bytes
|
||||
|
||||
|
||||
Raises:
|
||||
EncryptionError: If encryption fails
|
||||
"""
|
||||
@@ -399,32 +401,32 @@ def encrypt_message(
|
||||
salt = secrets.token_bytes(SALT_SIZE)
|
||||
key = derive_hybrid_key(photo_data, passphrase, salt, pin, rsa_key_data, channel_key)
|
||||
iv = secrets.token_bytes(IV_SIZE)
|
||||
|
||||
|
||||
# Determine flags
|
||||
flags = 0
|
||||
channel_hash = _resolve_channel_key(channel_key)
|
||||
if channel_hash:
|
||||
flags |= FLAG_CHANNEL_KEY
|
||||
|
||||
|
||||
# Pack payload with type marker
|
||||
packed_payload, _ = _pack_payload(message)
|
||||
|
||||
|
||||
# Random padding to hide message length
|
||||
padding_len = secrets.randbelow(256) + 64
|
||||
padded_len = ((len(packed_payload) + padding_len + 255) // 256) * 256
|
||||
padding_needed = padded_len - len(packed_payload)
|
||||
padding = secrets.token_bytes(padding_needed - 4) + struct.pack('>I', len(packed_payload))
|
||||
padded_message = packed_payload + padding
|
||||
|
||||
|
||||
# Build header for AAD
|
||||
header = MAGIC_HEADER + bytes([FORMAT_VERSION, flags])
|
||||
|
||||
|
||||
# Encrypt with AES-256-GCM
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encryptor.authenticate_additional_data(header)
|
||||
ciphertext = encryptor.update(padded_message) + encryptor.finalize()
|
||||
|
||||
|
||||
# v4.0.0: Header with flags byte
|
||||
return (
|
||||
header +
|
||||
@@ -433,34 +435,34 @@ def encrypt_message(
|
||||
encryptor.tag +
|
||||
ciphertext
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise EncryptionError(f"Encryption failed: {e}") from e
|
||||
|
||||
|
||||
def parse_header(encrypted_data: bytes) -> Optional[dict]:
|
||||
def parse_header(encrypted_data: bytes) -> dict | None:
|
||||
"""
|
||||
Parse the header from encrypted data.
|
||||
|
||||
|
||||
v4.0.0: Includes flags byte for channel key indicator.
|
||||
|
||||
|
||||
Args:
|
||||
encrypted_data: Raw encrypted bytes
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with salt, iv, tag, ciphertext, flags or None if invalid
|
||||
"""
|
||||
# Min size: Magic(4) + Version(1) + Flags(1) + Salt(32) + IV(12) + Tag(16) = 66 bytes
|
||||
if len(encrypted_data) < 66 or encrypted_data[:4] != MAGIC_HEADER:
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
version = encrypted_data[4]
|
||||
if version != FORMAT_VERSION:
|
||||
return None
|
||||
|
||||
|
||||
flags = encrypted_data[5]
|
||||
|
||||
|
||||
offset = 6
|
||||
salt = encrypted_data[offset:offset + SALT_SIZE]
|
||||
offset += SALT_SIZE
|
||||
@@ -469,7 +471,7 @@ def parse_header(encrypted_data: bytes) -> Optional[dict]:
|
||||
tag = encrypted_data[offset:offset + TAG_SIZE]
|
||||
offset += TAG_SIZE
|
||||
ciphertext = encrypted_data[offset:]
|
||||
|
||||
|
||||
return {
|
||||
'version': version,
|
||||
'flags': flags,
|
||||
@@ -488,12 +490,12 @@ def decrypt_message(
|
||||
photo_data: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> DecodeResult:
|
||||
"""
|
||||
Decrypt message (v4.0.0 - with channel key support).
|
||||
|
||||
|
||||
Args:
|
||||
encrypted_data: Encrypted message bytes
|
||||
photo_data: Reference photo bytes
|
||||
@@ -501,10 +503,10 @@ def decrypt_message(
|
||||
pin: Optional static PIN
|
||||
rsa_key_data: Optional RSA key bytes
|
||||
channel_key: Channel key parameter (see encrypt_message)
|
||||
|
||||
|
||||
Returns:
|
||||
DecodeResult with decrypted content
|
||||
|
||||
|
||||
Raises:
|
||||
InvalidHeaderError: If data doesn't have valid Stegasoo header
|
||||
DecryptionError: If decryption fails (wrong credentials)
|
||||
@@ -512,20 +514,20 @@ def decrypt_message(
|
||||
header = parse_header(encrypted_data)
|
||||
if not header:
|
||||
raise InvalidHeaderError("Invalid or missing Stegasoo header")
|
||||
|
||||
|
||||
# Check for channel key mismatch and provide helpful error
|
||||
channel_hash = _resolve_channel_key(channel_key)
|
||||
has_configured_key = channel_hash is not None
|
||||
message_has_key = header['has_channel_key']
|
||||
|
||||
|
||||
try:
|
||||
key = derive_hybrid_key(
|
||||
photo_data, passphrase, header['salt'], pin, rsa_key_data, channel_key
|
||||
)
|
||||
|
||||
|
||||
# Reconstruct header for AAD verification
|
||||
aad_header = MAGIC_HEADER + bytes([FORMAT_VERSION, header['flags']])
|
||||
|
||||
|
||||
cipher = Cipher(
|
||||
algorithms.AES(key),
|
||||
modes.GCM(header['iv'], header['tag']),
|
||||
@@ -533,15 +535,15 @@ def decrypt_message(
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
decryptor.authenticate_additional_data(aad_header)
|
||||
|
||||
|
||||
padded_plaintext = decryptor.update(header['ciphertext']) + decryptor.finalize()
|
||||
original_length = struct.unpack('>I', padded_plaintext[-4:])[0]
|
||||
|
||||
|
||||
payload_data = padded_plaintext[:original_length]
|
||||
result = _unpack_payload(payload_data)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Provide more helpful error message for channel key issues
|
||||
if message_has_key and not has_configured_key:
|
||||
@@ -566,14 +568,14 @@ def decrypt_message_text(
|
||||
photo_data: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Decrypt message and return as text string.
|
||||
|
||||
|
||||
For backward compatibility - returns text content or raises error for files.
|
||||
|
||||
|
||||
Args:
|
||||
encrypted_data: Encrypted message bytes
|
||||
photo_data: Reference photo bytes
|
||||
@@ -581,15 +583,15 @@ def decrypt_message_text(
|
||||
pin: Optional static PIN
|
||||
rsa_key_data: Optional RSA key bytes
|
||||
channel_key: Channel key parameter
|
||||
|
||||
|
||||
Returns:
|
||||
Decrypted message string
|
||||
|
||||
|
||||
Raises:
|
||||
DecryptionError: If decryption fails or content is a file
|
||||
"""
|
||||
result = decrypt_message(encrypted_data, photo_data, passphrase, pin, rsa_key_data, channel_key)
|
||||
|
||||
|
||||
if result.is_file:
|
||||
if result.file_data:
|
||||
# Try to decode as text
|
||||
@@ -600,7 +602,7 @@ def decrypt_message_text(
|
||||
f"Content is a binary file ({result.filename or 'unnamed'}), not text"
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
return result.message or ""
|
||||
|
||||
|
||||
@@ -613,10 +615,10 @@ def has_argon2() -> bool:
|
||||
# CHANNEL KEY UTILITIES (exposed for convenience)
|
||||
# =============================================================================
|
||||
|
||||
def get_active_channel_key() -> Optional[str]:
|
||||
def get_active_channel_key() -> str | None:
|
||||
"""
|
||||
Get the currently configured channel key (if any).
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted channel key string, or None if not configured
|
||||
"""
|
||||
@@ -624,7 +626,7 @@ def get_active_channel_key() -> Optional[str]:
|
||||
return get_channel_key()
|
||||
|
||||
|
||||
def get_channel_fingerprint(key: Optional[str] = None) -> Optional[str]:
|
||||
def get_channel_fingerprint(key: str | None = None) -> str | None:
|
||||
"""
|
||||
Get a display-safe fingerprint of a channel key.
|
||||
|
||||
|
||||
@@ -14,12 +14,11 @@ v3.2.0-patch2 Changes:
|
||||
Requires: scipy (for PNG mode), optionally jpegio (for JPEG mode)
|
||||
"""
|
||||
|
||||
import gc
|
||||
import hashlib
|
||||
import io
|
||||
import struct
|
||||
import hashlib
|
||||
import gc
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
@@ -103,7 +102,7 @@ class DCTEmbedStats:
|
||||
color_mode: str = 'grayscale'
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class DCTCapacityInfo:
|
||||
width: int
|
||||
height: int
|
||||
@@ -147,19 +146,19 @@ def _safe_dct2(block: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
# Create a brand new array (not a view)
|
||||
safe_block = np.array(block, dtype=np.float64, copy=True, order='C')
|
||||
|
||||
|
||||
# First DCT on columns (transpose -> DCT rows -> transpose back)
|
||||
temp = np.zeros_like(safe_block, dtype=np.float64, order='C')
|
||||
for i in range(BLOCK_SIZE):
|
||||
col = np.array(safe_block[:, i], dtype=np.float64, copy=True)
|
||||
temp[:, i] = dct(col, norm='ortho')
|
||||
|
||||
|
||||
# Second DCT on rows
|
||||
result = np.zeros_like(temp, dtype=np.float64, order='C')
|
||||
for i in range(BLOCK_SIZE):
|
||||
row = np.array(temp[i, :], dtype=np.float64, copy=True)
|
||||
result[i, :] = dct(row, norm='ortho')
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -170,19 +169,19 @@ def _safe_idct2(block: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
# Create a brand new array (not a view)
|
||||
safe_block = np.array(block, dtype=np.float64, copy=True, order='C')
|
||||
|
||||
|
||||
# First IDCT on rows
|
||||
temp = np.zeros_like(safe_block, dtype=np.float64, order='C')
|
||||
for i in range(BLOCK_SIZE):
|
||||
row = np.array(safe_block[i, :], dtype=np.float64, copy=True)
|
||||
temp[i, :] = idct(row, norm='ortho')
|
||||
|
||||
|
||||
# Second IDCT on columns
|
||||
result = np.zeros_like(temp, dtype=np.float64, order='C')
|
||||
for i in range(BLOCK_SIZE):
|
||||
col = np.array(temp[:, i], dtype=np.float64, copy=True)
|
||||
result[:, i] = idct(col, norm='ortho')
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -200,23 +199,23 @@ def _extract_y_channel(image_data: bytes) -> np.ndarray:
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
|
||||
|
||||
rgb = np.array(img, dtype=np.float64, copy=True, order='C')
|
||||
Y = 0.299 * rgb[:, :, 0] + 0.587 * rgb[:, :, 1] + 0.114 * rgb[:, :, 2]
|
||||
return np.array(Y, dtype=np.float64, copy=True, order='C')
|
||||
|
||||
|
||||
def _pad_to_blocks(image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
|
||||
def _pad_to_blocks(image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
|
||||
h, w = image.shape
|
||||
new_h = ((h + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
|
||||
new_w = ((w + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
|
||||
|
||||
|
||||
if new_h == h and new_w == w:
|
||||
return np.array(image, dtype=np.float64, copy=True, order='C'), (h, w)
|
||||
|
||||
|
||||
padded = np.zeros((new_h, new_w), dtype=np.float64, order='C')
|
||||
padded[:h, :w] = image
|
||||
|
||||
|
||||
# Simple edge replication for padding
|
||||
if new_h > h:
|
||||
for i in range(h, new_h):
|
||||
@@ -226,11 +225,11 @@ def _pad_to_blocks(image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
|
||||
padded[:h, j] = padded[:h, w-1]
|
||||
if new_h > h and new_w > w:
|
||||
padded[h:, w:] = padded[h-1, w-1]
|
||||
|
||||
|
||||
return padded, (h, w)
|
||||
|
||||
|
||||
def _unpad_image(image: np.ndarray, original_size: Tuple[int, int]) -> np.ndarray:
|
||||
def _unpad_image(image: np.ndarray, original_size: tuple[int, int]) -> np.ndarray:
|
||||
h, w = original_size
|
||||
return np.array(image[:h, :w], dtype=np.float64, copy=True, order='C')
|
||||
|
||||
@@ -263,7 +262,7 @@ def _save_stego_image(image: np.ndarray, output_format: str = OUTPUT_FORMAT_PNG)
|
||||
img = Image.fromarray(clipped, mode='L')
|
||||
buffer = io.BytesIO()
|
||||
if output_format == OUTPUT_FORMAT_JPEG:
|
||||
img.save(buffer, format='JPEG', quality=JPEG_OUTPUT_QUALITY,
|
||||
img.save(buffer, format='JPEG', quality=JPEG_OUTPUT_QUALITY,
|
||||
subsampling=0, optimize=True)
|
||||
else:
|
||||
img.save(buffer, format='PNG', optimize=True)
|
||||
@@ -282,15 +281,15 @@ def _save_color_image(rgb_array: np.ndarray, output_format: str = OUTPUT_FORMAT_
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def _rgb_to_ycbcr(rgb: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
def _rgb_to_ycbcr(rgb: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
R = rgb[:, :, 0].astype(np.float64)
|
||||
G = rgb[:, :, 1].astype(np.float64)
|
||||
B = rgb[:, :, 2].astype(np.float64)
|
||||
|
||||
|
||||
Y = np.array(0.299 * R + 0.587 * G + 0.114 * B, dtype=np.float64, copy=True, order='C')
|
||||
Cb = np.array(128 - 0.168736 * R - 0.331264 * G + 0.5 * B, dtype=np.float64, copy=True, order='C')
|
||||
Cr = np.array(128 + 0.5 * R - 0.418688 * G - 0.081312 * B, dtype=np.float64, copy=True, order='C')
|
||||
|
||||
|
||||
return Y, Cb, Cr
|
||||
|
||||
|
||||
@@ -298,7 +297,7 @@ def _ycbcr_to_rgb(Y: np.ndarray, Cb: np.ndarray, Cr: np.ndarray) -> np.ndarray:
|
||||
R = Y + 1.402 * (Cr - 128)
|
||||
G = Y - 0.344136 * (Cb - 128) - 0.714136 * (Cr - 128)
|
||||
B = Y + 1.772 * (Cb - 128)
|
||||
|
||||
|
||||
rgb = np.zeros((Y.shape[0], Y.shape[1], 3), dtype=np.float64, order='C')
|
||||
rgb[:, :, 0] = R
|
||||
rgb[:, :, 1] = G
|
||||
@@ -310,20 +309,20 @@ def _create_header(data_length: int, flags: int = 0) -> bytes:
|
||||
return struct.pack('>4sBBI', DCT_MAGIC, 1, flags, data_length)
|
||||
|
||||
|
||||
def _parse_header(header_bits: list) -> Tuple[int, int, int]:
|
||||
def _parse_header(header_bits: list) -> tuple[int, int, int]:
|
||||
if len(header_bits) < HEADER_SIZE * 8:
|
||||
raise ValueError("Insufficient header data")
|
||||
|
||||
|
||||
header_bytes = bytes([
|
||||
sum(header_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8))
|
||||
for i in range(HEADER_SIZE)
|
||||
])
|
||||
|
||||
|
||||
magic, version, flags, length = struct.unpack('>4sBBI', header_bytes)
|
||||
|
||||
|
||||
if magic != DCT_MAGIC:
|
||||
raise ValueError("Invalid DCT stego magic bytes")
|
||||
|
||||
|
||||
return version, flags, length
|
||||
|
||||
|
||||
@@ -332,8 +331,8 @@ def _parse_header(header_bits: list) -> Tuple[int, int, int]:
|
||||
# ============================================================================
|
||||
|
||||
def _jpegio_bytes_to_file(data: bytes, suffix: str = '.jpg') -> str:
|
||||
import tempfile
|
||||
import os
|
||||
import tempfile
|
||||
fd, path = tempfile.mkstemp(suffix=suffix)
|
||||
try:
|
||||
os.write(fd, data)
|
||||
@@ -366,7 +365,7 @@ def _jpegio_create_header(data_length: int, flags: int = 0) -> bytes:
|
||||
return struct.pack('>4sBBI', JPEGIO_MAGIC, 1, flags, data_length)
|
||||
|
||||
|
||||
def _jpegio_parse_header(header_bytes: bytes) -> Tuple[int, int, int]:
|
||||
def _jpegio_parse_header(header_bytes: bytes) -> tuple[int, int, int]:
|
||||
if len(header_bytes) < HEADER_SIZE:
|
||||
raise ValueError("Insufficient header data")
|
||||
magic, version, flags, length = struct.unpack('>4sBBI', header_bytes[:HEADER_SIZE])
|
||||
@@ -382,21 +381,21 @@ def _jpegio_parse_header(header_bytes: bytes) -> Tuple[int, int, int]:
|
||||
def calculate_dct_capacity(image_data: bytes) -> DCTCapacityInfo:
|
||||
"""Calculate DCT embedding capacity of an image."""
|
||||
_check_scipy()
|
||||
|
||||
|
||||
# Just get dimensions, don't process anything
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
width, height = img.size
|
||||
img.close() # Explicitly close
|
||||
|
||||
|
||||
blocks_x = width // BLOCK_SIZE
|
||||
blocks_y = height // BLOCK_SIZE
|
||||
total_blocks = blocks_x * blocks_y
|
||||
|
||||
|
||||
bits_per_block = len(DEFAULT_EMBED_POSITIONS)
|
||||
total_bits = total_blocks * bits_per_block
|
||||
total_bytes = total_bits // 8
|
||||
usable_bytes = max(0, total_bytes - HEADER_SIZE)
|
||||
|
||||
|
||||
return DCTCapacityInfo(
|
||||
width=width,
|
||||
height=height,
|
||||
@@ -420,13 +419,13 @@ def estimate_capacity_comparison(image_data: bytes) -> dict:
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
width, height = img.size
|
||||
img.close()
|
||||
|
||||
|
||||
pixels = width * height
|
||||
lsb_bytes = (pixels * 3) // 8
|
||||
|
||||
|
||||
blocks = (width // 8) * (height // 8)
|
||||
dct_bytes = (blocks * 16) // 8 - HEADER_SIZE
|
||||
|
||||
|
||||
return {
|
||||
'width': width,
|
||||
'height': height,
|
||||
@@ -455,17 +454,17 @@ def embed_in_dct(
|
||||
seed: bytes,
|
||||
output_format: str = OUTPUT_FORMAT_PNG,
|
||||
color_mode: str = 'color',
|
||||
) -> Tuple[bytes, DCTEmbedStats]:
|
||||
) -> tuple[bytes, DCTEmbedStats]:
|
||||
"""Embed data using DCT coefficient modification."""
|
||||
if output_format not in (OUTPUT_FORMAT_PNG, OUTPUT_FORMAT_JPEG):
|
||||
raise ValueError(f"Invalid output format: {output_format}")
|
||||
|
||||
|
||||
if color_mode not in ('color', 'grayscale'):
|
||||
color_mode = 'color'
|
||||
|
||||
|
||||
if output_format == OUTPUT_FORMAT_JPEG and HAS_JPEGIO:
|
||||
return _embed_jpegio(data, carrier_image, seed, color_mode)
|
||||
|
||||
|
||||
_check_scipy()
|
||||
return _embed_scipy_dct_safe(data, carrier_image, seed, output_format, color_mode)
|
||||
|
||||
@@ -476,27 +475,27 @@ def _embed_scipy_dct_safe(
|
||||
seed: bytes,
|
||||
output_format: str,
|
||||
color_mode: str = 'color',
|
||||
) -> Tuple[bytes, DCTEmbedStats]:
|
||||
) -> tuple[bytes, DCTEmbedStats]:
|
||||
"""
|
||||
Embed using scipy DCT with safe memory handling.
|
||||
|
||||
|
||||
Uses row-by-row 1D DCT operations instead of 2D arrays to avoid
|
||||
scipy memory corruption issues with large images.
|
||||
"""
|
||||
capacity_info = calculate_dct_capacity(carrier_image)
|
||||
|
||||
|
||||
if len(data) > capacity_info.usable_capacity_bytes:
|
||||
raise ValueError(
|
||||
f"Data too large ({len(data)} bytes) for carrier "
|
||||
f"(capacity: {capacity_info.usable_capacity_bytes} bytes)"
|
||||
)
|
||||
|
||||
|
||||
# Load image
|
||||
img = Image.open(io.BytesIO(carrier_image))
|
||||
width, height = img.size
|
||||
|
||||
|
||||
flags = FLAG_COLOR_MODE if color_mode == 'color' else 0
|
||||
|
||||
|
||||
# Prepare payload bits
|
||||
header = _create_header(len(data), flags)
|
||||
payload = header + data
|
||||
@@ -504,41 +503,41 @@ def _embed_scipy_dct_safe(
|
||||
for byte in payload:
|
||||
for i in range(7, -1, -1):
|
||||
bits.append((byte >> i) & 1)
|
||||
|
||||
|
||||
# Generate block order
|
||||
num_blocks = capacity_info.total_blocks
|
||||
block_order = _generate_block_order(num_blocks, seed)
|
||||
blocks_x = width // BLOCK_SIZE
|
||||
|
||||
|
||||
if color_mode == 'color' and img.mode in ('RGB', 'RGBA'):
|
||||
if img.mode == 'RGBA':
|
||||
img = img.convert('RGB')
|
||||
|
||||
|
||||
# Process color image
|
||||
rgb = np.array(img, dtype=np.float64, copy=True, order='C')
|
||||
img.close()
|
||||
|
||||
|
||||
Y, Cb, Cr = _rgb_to_ycbcr(rgb)
|
||||
del rgb
|
||||
gc.collect()
|
||||
|
||||
|
||||
Y_padded, original_size = _pad_to_blocks(Y)
|
||||
del Y
|
||||
gc.collect()
|
||||
|
||||
|
||||
# Embed in Y channel
|
||||
Y_embedded = _embed_in_channel_safe(Y_padded, bits, block_order, blocks_x)
|
||||
del Y_padded
|
||||
gc.collect()
|
||||
|
||||
|
||||
Y_result = _unpad_image(Y_embedded, original_size)
|
||||
del Y_embedded
|
||||
gc.collect()
|
||||
|
||||
|
||||
result_rgb = _ycbcr_to_rgb(Y_result, Cb, Cr)
|
||||
del Y_result, Cb, Cr
|
||||
gc.collect()
|
||||
|
||||
|
||||
stego_bytes = _save_color_image(result_rgb, output_format)
|
||||
del result_rgb
|
||||
gc.collect()
|
||||
@@ -546,23 +545,23 @@ def _embed_scipy_dct_safe(
|
||||
# Grayscale mode
|
||||
image = _to_grayscale(carrier_image)
|
||||
img.close()
|
||||
|
||||
|
||||
padded, original_size = _pad_to_blocks(image)
|
||||
del image
|
||||
gc.collect()
|
||||
|
||||
|
||||
embedded = _embed_in_channel_safe(padded, bits, block_order, blocks_x)
|
||||
del padded
|
||||
gc.collect()
|
||||
|
||||
|
||||
result = _unpad_image(embedded, original_size)
|
||||
del embedded
|
||||
gc.collect()
|
||||
|
||||
|
||||
stego_bytes = _save_stego_image(result, output_format)
|
||||
del result
|
||||
gc.collect()
|
||||
|
||||
|
||||
stats = DCTEmbedStats(
|
||||
blocks_used=(len(bits) + len(DEFAULT_EMBED_POSITIONS) - 1) // len(DEFAULT_EMBED_POSITIONS),
|
||||
blocks_available=capacity_info.total_blocks,
|
||||
@@ -575,7 +574,7 @@ def _embed_scipy_dct_safe(
|
||||
jpeg_native=False,
|
||||
color_mode=color_mode,
|
||||
)
|
||||
|
||||
|
||||
return stego_bytes, stats
|
||||
|
||||
|
||||
@@ -587,78 +586,78 @@ def _embed_in_channel_safe(
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Embed bits in channel using safe DCT operations.
|
||||
|
||||
|
||||
Processes one block at a time with fresh array allocations.
|
||||
"""
|
||||
h, w = channel.shape
|
||||
|
||||
|
||||
# Create result with explicit new memory
|
||||
result = np.array(channel, dtype=np.float64, copy=True, order='C')
|
||||
|
||||
|
||||
bit_idx = 0
|
||||
|
||||
|
||||
for block_num in block_order:
|
||||
if bit_idx >= len(bits):
|
||||
break
|
||||
|
||||
|
||||
by = (block_num // blocks_x) * BLOCK_SIZE
|
||||
bx = (block_num % blocks_x) * BLOCK_SIZE
|
||||
|
||||
|
||||
# Extract block - create brand new array
|
||||
block = np.array(
|
||||
result[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE],
|
||||
dtype=np.float64, copy=True, order='C'
|
||||
)
|
||||
|
||||
|
||||
# Apply safe DCT (row-by-row)
|
||||
dct_block = _safe_dct2(block)
|
||||
|
||||
|
||||
# Embed bits
|
||||
for pos in DEFAULT_EMBED_POSITIONS:
|
||||
if bit_idx >= len(bits):
|
||||
break
|
||||
dct_block[pos[0], pos[1]] = _embed_bit_in_coeff(
|
||||
float(dct_block[pos[0], pos[1]]),
|
||||
float(dct_block[pos[0], pos[1]]),
|
||||
bits[bit_idx]
|
||||
)
|
||||
bit_idx += 1
|
||||
|
||||
|
||||
# Apply safe inverse DCT
|
||||
modified_block = _safe_idct2(dct_block)
|
||||
|
||||
|
||||
# Copy back
|
||||
result[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE] = modified_block
|
||||
|
||||
|
||||
# Clean up this iteration
|
||||
del block, dct_block, modified_block
|
||||
|
||||
|
||||
# Force garbage collection
|
||||
gc.collect()
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_jpeg_for_jpegio(image_data: bytes) -> bytes:
|
||||
"""
|
||||
Normalize a JPEG image to ensure jpegio can process it safely.
|
||||
|
||||
|
||||
JPEGs saved with quality=100 have quantization tables with all values = 1,
|
||||
which causes jpegio to crash due to huge coefficient magnitudes.
|
||||
This function detects such images and re-saves them at a safe quality level.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Raw JPEG bytes
|
||||
|
||||
|
||||
Returns:
|
||||
Normalized JPEG bytes (may be unchanged if already safe)
|
||||
"""
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
|
||||
# Only process JPEGs
|
||||
if img.format != 'JPEG':
|
||||
img.close()
|
||||
return image_data
|
||||
|
||||
|
||||
# Check quantization tables
|
||||
needs_normalization = False
|
||||
if hasattr(img, 'quantization') and img.quantization:
|
||||
@@ -667,19 +666,19 @@ def _normalize_jpeg_for_jpegio(image_data: bytes) -> bytes:
|
||||
if max(table) <= JPEGIO_MAX_QUANT_VALUE_THRESHOLD:
|
||||
needs_normalization = True
|
||||
break
|
||||
|
||||
|
||||
if not needs_normalization:
|
||||
img.close()
|
||||
return image_data
|
||||
|
||||
|
||||
# Re-save at safe quality level
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
|
||||
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='JPEG', quality=JPEGIO_NORMALIZE_QUALITY, subsampling=0)
|
||||
img.close()
|
||||
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
@@ -688,17 +687,17 @@ def _embed_jpegio(
|
||||
carrier_image: bytes,
|
||||
seed: bytes,
|
||||
color_mode: str = 'color',
|
||||
) -> Tuple[bytes, DCTEmbedStats]:
|
||||
) -> tuple[bytes, DCTEmbedStats]:
|
||||
"""Embed using jpegio for proper JPEG coefficient modification."""
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
import tempfile
|
||||
|
||||
# Normalize JPEG to avoid crashes with quality=100 images
|
||||
carrier_image = _normalize_jpeg_for_jpegio(carrier_image)
|
||||
|
||||
|
||||
img = Image.open(io.BytesIO(carrier_image))
|
||||
width, height = img.size
|
||||
|
||||
|
||||
if img.format != 'JPEG':
|
||||
buffer = io.BytesIO()
|
||||
if img.mode != 'RGB':
|
||||
@@ -706,54 +705,54 @@ def _embed_jpegio(
|
||||
img.save(buffer, format='JPEG', quality=95, subsampling=0)
|
||||
carrier_image = buffer.getvalue()
|
||||
img.close()
|
||||
|
||||
|
||||
input_path = _jpegio_bytes_to_file(carrier_image, suffix='.jpg')
|
||||
output_path = tempfile.mktemp(suffix='.jpg')
|
||||
|
||||
|
||||
flags = FLAG_COLOR_MODE if color_mode == 'color' else 0
|
||||
|
||||
|
||||
try:
|
||||
jpeg = jio.read(input_path)
|
||||
coef_array = jpeg.coef_arrays[JPEGIO_EMBED_CHANNEL]
|
||||
|
||||
|
||||
all_positions = _jpegio_get_usable_positions(coef_array)
|
||||
order = _jpegio_generate_order(len(all_positions), seed)
|
||||
|
||||
|
||||
header = _jpegio_create_header(len(data), flags)
|
||||
payload = header + data
|
||||
|
||||
|
||||
bits = []
|
||||
for byte in payload:
|
||||
for i in range(7, -1, -1):
|
||||
bits.append((byte >> i) & 1)
|
||||
|
||||
|
||||
if len(bits) > len(all_positions):
|
||||
raise ValueError(
|
||||
f"Payload too large: {len(bits)} bits, "
|
||||
f"only {len(all_positions)} usable coefficients"
|
||||
)
|
||||
|
||||
|
||||
coefs_used = 0
|
||||
for bit_idx, pos_idx in enumerate(order):
|
||||
if bit_idx >= len(bits):
|
||||
break
|
||||
|
||||
|
||||
row, col = all_positions[pos_idx]
|
||||
coef = coef_array[row, col]
|
||||
|
||||
|
||||
if (coef & 1) != bits[bit_idx]:
|
||||
if coef > 0:
|
||||
coef_array[row, col] = coef - 1 if (coef & 1) else coef + 1
|
||||
else:
|
||||
coef_array[row, col] = coef + 1 if (coef & 1) else coef - 1
|
||||
|
||||
|
||||
coefs_used += 1
|
||||
|
||||
|
||||
jio.write(jpeg, output_path)
|
||||
|
||||
|
||||
with open(output_path, 'rb') as f:
|
||||
stego_bytes = f.read()
|
||||
|
||||
|
||||
stats = DCTEmbedStats(
|
||||
blocks_used=coefs_used // 63,
|
||||
blocks_available=len(all_positions) // 63,
|
||||
@@ -766,9 +765,9 @@ def _embed_jpegio(
|
||||
jpeg_native=True,
|
||||
color_mode=color_mode,
|
||||
)
|
||||
|
||||
|
||||
return stego_bytes, stats
|
||||
|
||||
|
||||
finally:
|
||||
for path in [input_path, output_path]:
|
||||
try:
|
||||
@@ -782,13 +781,13 @@ def extract_from_dct(stego_image: bytes, seed: bytes) -> bytes:
|
||||
img = Image.open(io.BytesIO(stego_image))
|
||||
fmt = img.format
|
||||
img.close()
|
||||
|
||||
|
||||
if fmt == 'JPEG' and HAS_JPEGIO:
|
||||
try:
|
||||
return _extract_jpegio(stego_image, seed)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
_check_scipy()
|
||||
return _extract_scipy_dct_safe(stego_image, seed)
|
||||
|
||||
@@ -798,41 +797,41 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
|
||||
img = Image.open(io.BytesIO(stego_image))
|
||||
width, height = img.size
|
||||
mode = img.mode
|
||||
|
||||
|
||||
if mode in ('RGB', 'RGBA'):
|
||||
channel = _extract_y_channel(stego_image)
|
||||
else:
|
||||
channel = _to_grayscale(stego_image)
|
||||
img.close()
|
||||
|
||||
|
||||
padded, _ = _pad_to_blocks(channel)
|
||||
del channel
|
||||
gc.collect()
|
||||
|
||||
|
||||
h, w = padded.shape
|
||||
blocks_x = w // BLOCK_SIZE
|
||||
num_blocks = (h // BLOCK_SIZE) * blocks_x
|
||||
|
||||
|
||||
block_order = _generate_block_order(num_blocks, seed)
|
||||
|
||||
|
||||
all_bits = []
|
||||
|
||||
|
||||
for block_num in block_order:
|
||||
by = (block_num // blocks_x) * BLOCK_SIZE
|
||||
bx = (block_num % blocks_x) * BLOCK_SIZE
|
||||
|
||||
|
||||
block = np.array(
|
||||
padded[by:by+BLOCK_SIZE, bx:bx+BLOCK_SIZE],
|
||||
dtype=np.float64, copy=True, order='C'
|
||||
)
|
||||
dct_block = _safe_dct2(block)
|
||||
|
||||
|
||||
for pos in DEFAULT_EMBED_POSITIONS:
|
||||
bit = _extract_bit_from_coeff(float(dct_block[pos[0], pos[1]]))
|
||||
all_bits.append(bit)
|
||||
|
||||
|
||||
del block, dct_block
|
||||
|
||||
|
||||
if len(all_bits) >= HEADER_SIZE * 8:
|
||||
try:
|
||||
_, flags, data_length = _parse_header(all_bits[:HEADER_SIZE * 8])
|
||||
@@ -841,53 +840,53 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
del padded
|
||||
gc.collect()
|
||||
|
||||
|
||||
_, flags, data_length = _parse_header(all_bits)
|
||||
data_bits = all_bits[HEADER_SIZE * 8:(HEADER_SIZE + data_length) * 8]
|
||||
|
||||
|
||||
data = bytes([
|
||||
sum(data_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8))
|
||||
for i in range(data_length)
|
||||
])
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
|
||||
"""Extract using jpegio for JPEG images."""
|
||||
import os
|
||||
|
||||
|
||||
# Normalize JPEG to avoid crashes with quality=100 images
|
||||
# (shouldn't happen with stego images, but be defensive)
|
||||
stego_image = _normalize_jpeg_for_jpegio(stego_image)
|
||||
|
||||
|
||||
temp_path = _jpegio_bytes_to_file(stego_image, suffix='.jpg')
|
||||
|
||||
|
||||
try:
|
||||
jpeg = jio.read(temp_path)
|
||||
coef_array = jpeg.coef_arrays[JPEGIO_EMBED_CHANNEL]
|
||||
|
||||
|
||||
all_positions = _jpegio_get_usable_positions(coef_array)
|
||||
order = _jpegio_generate_order(len(all_positions), seed)
|
||||
|
||||
|
||||
header_bits = []
|
||||
for pos_idx in order[:HEADER_SIZE * 8]:
|
||||
row, col = all_positions[pos_idx]
|
||||
coef = coef_array[row, col]
|
||||
header_bits.append(coef & 1)
|
||||
|
||||
|
||||
header_bytes = bytes([
|
||||
sum(header_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8))
|
||||
for i in range(HEADER_SIZE)
|
||||
])
|
||||
|
||||
|
||||
_, flags, data_length = _jpegio_parse_header(header_bytes)
|
||||
|
||||
|
||||
total_bits_needed = (HEADER_SIZE + data_length) * 8
|
||||
|
||||
|
||||
all_bits = []
|
||||
for bit_idx, pos_idx in enumerate(order):
|
||||
if bit_idx >= total_bits_needed:
|
||||
@@ -895,16 +894,16 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
|
||||
row, col = all_positions[pos_idx]
|
||||
coef = coef_array[row, col]
|
||||
all_bits.append(coef & 1)
|
||||
|
||||
|
||||
data_bits = all_bits[HEADER_SIZE * 8:]
|
||||
|
||||
|
||||
data = bytes([
|
||||
sum(data_bits[i*8:(i+1)*8][j] << (7-j) for j in range(8))
|
||||
for i in range(data_length)
|
||||
])
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
finally:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
|
||||
@@ -5,12 +5,13 @@ Debugging, logging, and performance monitoring tools.
|
||||
Can be disabled for production use.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Callable, Any, Optional, Dict, Union
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
# Global debug configuration
|
||||
DEBUG_ENABLED = False # Set to True to enable debug output
|
||||
@@ -47,10 +48,10 @@ def debug_data(data: bytes, label: str = "Data", max_bytes: int = 32) -> str:
|
||||
"""Format bytes for debugging."""
|
||||
if not DEBUG_ENABLED:
|
||||
return ""
|
||||
|
||||
|
||||
if not data:
|
||||
return f"{label}: Empty"
|
||||
|
||||
|
||||
if len(data) <= max_bytes:
|
||||
return f"{label} ({len(data)} bytes): {data.hex()}"
|
||||
else:
|
||||
@@ -71,7 +72,7 @@ def time_function(func: Callable) -> Callable:
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
if not (DEBUG_ENABLED and LOG_PERFORMANCE):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
@@ -79,7 +80,7 @@ def time_function(func: Callable) -> Callable:
|
||||
finally:
|
||||
end = time.perf_counter()
|
||||
debug_print(f"{func.__name__} took {end - start:.6f}s", "PERF")
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -89,14 +90,15 @@ def validate_assertion(condition: bool, message: str) -> None:
|
||||
raise AssertionError(f"Validation failed: {message}")
|
||||
|
||||
|
||||
def memory_usage() -> Dict[str, Union[float, str]]:
|
||||
def memory_usage() -> dict[str, float | str]:
|
||||
"""Get current memory usage (if psutil is available)."""
|
||||
try:
|
||||
import psutil
|
||||
import os
|
||||
|
||||
import psutil
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
|
||||
|
||||
return {
|
||||
'rss_mb': mem_info.rss / 1024 / 1024,
|
||||
'vms_mb': mem_info.vms / 1024 / 1024,
|
||||
@@ -110,66 +112,66 @@ def hexdump(data: bytes, offset: int = 0, length: int = 64) -> str:
|
||||
"""Create hexdump string for debugging binary data."""
|
||||
if not data:
|
||||
return "Empty"
|
||||
|
||||
|
||||
result = []
|
||||
data_to_dump = data[:length]
|
||||
|
||||
|
||||
for i in range(0, len(data_to_dump), 16):
|
||||
chunk = data_to_dump[i:i+16]
|
||||
hex_str = ' '.join(f'{b:02x}' for b in chunk)
|
||||
hex_str = hex_str.ljust(47)
|
||||
ascii_str = ''.join(chr(b) if 32 <= b < 127 else '.' for b in chunk)
|
||||
result.append(f"{offset + i:08x}: {hex_str} {ascii_str}")
|
||||
|
||||
|
||||
if len(data) > length:
|
||||
result.append(f"... ({len(data) - length} more bytes)")
|
||||
|
||||
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
class Debug:
|
||||
"""Debugging utility class."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.enabled = DEBUG_ENABLED
|
||||
|
||||
|
||||
def print(self, message: str, level: str = "INFO") -> None:
|
||||
"""Print debug message."""
|
||||
debug_print(message, level)
|
||||
|
||||
|
||||
def data(self, data: bytes, label: str = "Data", max_bytes: int = 32) -> str:
|
||||
"""Format bytes for debugging."""
|
||||
return debug_data(data, label, max_bytes)
|
||||
|
||||
|
||||
def exception(self, e: Exception, context: str = "") -> None:
|
||||
"""Log exception with context."""
|
||||
debug_exception(e, context)
|
||||
|
||||
|
||||
def time(self, func: Callable) -> Callable:
|
||||
"""Decorator to time function execution."""
|
||||
return time_function(func)
|
||||
|
||||
|
||||
def validate(self, condition: bool, message: str) -> None:
|
||||
"""Runtime validation assertion."""
|
||||
validate_assertion(condition, message)
|
||||
|
||||
def memory(self) -> Dict[str, Union[float, str]]:
|
||||
|
||||
def memory(self) -> dict[str, float | str]:
|
||||
"""Get current memory usage."""
|
||||
return memory_usage()
|
||||
|
||||
|
||||
def hexdump(self, data: bytes, offset: int = 0, length: int = 64) -> str:
|
||||
"""Create hexdump string."""
|
||||
return hexdump(data, offset, length)
|
||||
|
||||
|
||||
def enable(self, enable: bool = True) -> None:
|
||||
"""Enable or disable debug mode."""
|
||||
enable_debug(enable)
|
||||
self.enabled = enable
|
||||
|
||||
|
||||
def enable_performance(self, enable: bool = True) -> None:
|
||||
"""Enable or disable performance logging."""
|
||||
enable_performance_logging(enable)
|
||||
|
||||
|
||||
def enable_assertions(self, enable: bool = True) -> None:
|
||||
"""Enable or disable validation assertions."""
|
||||
enable_assertions(enable)
|
||||
|
||||
@@ -8,21 +8,20 @@ Changes in v4.0.0:
|
||||
- Improved error messages for channel key mismatches
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
from pathlib import Path
|
||||
|
||||
from .models import DecodeInput, DecodeResult
|
||||
from .constants import EMBED_MODE_AUTO
|
||||
from .crypto import decrypt_message
|
||||
from .debug import debug
|
||||
from .exceptions import DecryptionError, ExtractionError
|
||||
from .models import DecodeResult
|
||||
from .steganography import extract_from_image
|
||||
from .validation import (
|
||||
require_valid_image,
|
||||
require_security_factors,
|
||||
require_valid_image,
|
||||
require_valid_pin,
|
||||
require_valid_rsa_key,
|
||||
)
|
||||
from .constants import EMBED_MODE_AUTO
|
||||
from .exceptions import ExtractionError, DecryptionError
|
||||
from .debug import debug
|
||||
|
||||
|
||||
def decode(
|
||||
@@ -30,14 +29,14 @@ def decode(
|
||||
reference_photo: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
rsa_password: Optional[str] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
rsa_password: str | None = None,
|
||||
embed_mode: str = EMBED_MODE_AUTO,
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> DecodeResult:
|
||||
"""
|
||||
Decode a message or file from a stego image.
|
||||
|
||||
|
||||
Args:
|
||||
stego_image: Stego image bytes
|
||||
reference_photo: Shared reference photo bytes
|
||||
@@ -50,10 +49,10 @@ def decode(
|
||||
- None or "auto": Use server's configured key
|
||||
- str: Use this specific channel key
|
||||
- "" or False: No channel key (public mode)
|
||||
|
||||
|
||||
Returns:
|
||||
DecodeResult with message or file data
|
||||
|
||||
|
||||
Example:
|
||||
>>> result = decode(
|
||||
... stego_image=stego_bytes,
|
||||
@@ -66,7 +65,7 @@ def decode(
|
||||
... else:
|
||||
... with open(result.filename, 'wb') as f:
|
||||
... f.write(result.file_data)
|
||||
|
||||
|
||||
Example with explicit channel key:
|
||||
>>> result = decode(
|
||||
... stego_image=stego_bytes,
|
||||
@@ -79,41 +78,41 @@ def decode(
|
||||
debug.print(f"decode: passphrase length={len(passphrase.split())} words, "
|
||||
f"mode={embed_mode}, "
|
||||
f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}")
|
||||
|
||||
|
||||
# Validate inputs
|
||||
require_valid_image(stego_image, "Stego image")
|
||||
require_valid_image(reference_photo, "Reference photo")
|
||||
require_security_factors(pin, rsa_key_data)
|
||||
|
||||
|
||||
if pin:
|
||||
require_valid_pin(pin)
|
||||
if rsa_key_data:
|
||||
require_valid_rsa_key(rsa_key_data, rsa_password)
|
||||
|
||||
|
||||
# Derive pixel/coefficient selection key (with channel key)
|
||||
from .crypto import derive_pixel_key
|
||||
pixel_key = derive_pixel_key(
|
||||
reference_photo, passphrase, pin, rsa_key_data, channel_key
|
||||
)
|
||||
|
||||
|
||||
# Extract encrypted data
|
||||
encrypted = extract_from_image(
|
||||
stego_image,
|
||||
pixel_key,
|
||||
embed_mode=embed_mode,
|
||||
)
|
||||
|
||||
|
||||
if not encrypted:
|
||||
debug.print("No data extracted from image")
|
||||
raise ExtractionError("Could not extract data. Check your credentials and image.")
|
||||
|
||||
|
||||
debug.print(f"Extracted {len(encrypted)} bytes from image")
|
||||
|
||||
|
||||
# Decrypt (with channel key)
|
||||
result = decrypt_message(
|
||||
encrypted, reference_photo, passphrase, pin, rsa_key_data, channel_key
|
||||
)
|
||||
|
||||
|
||||
debug.print(f"Decryption successful: {result.payload_type}")
|
||||
return result
|
||||
|
||||
@@ -122,16 +121,16 @@ def decode_file(
|
||||
stego_image: bytes,
|
||||
reference_photo: bytes,
|
||||
passphrase: str,
|
||||
output_path: Optional[Path] = None,
|
||||
output_path: Path | None = None,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
rsa_password: Optional[str] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
rsa_password: str | None = None,
|
||||
embed_mode: str = EMBED_MODE_AUTO,
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> Path:
|
||||
"""
|
||||
Decode a file from a stego image and save it.
|
||||
|
||||
|
||||
Args:
|
||||
stego_image: Stego image bytes
|
||||
reference_photo: Shared reference photo bytes
|
||||
@@ -142,10 +141,10 @@ def decode_file(
|
||||
rsa_password: Optional RSA key password
|
||||
embed_mode: 'auto', 'lsb', or 'dct'
|
||||
channel_key: Channel key parameter (see decode())
|
||||
|
||||
|
||||
Returns:
|
||||
Path where file was saved
|
||||
|
||||
|
||||
Raises:
|
||||
DecryptionError: If payload is text, not a file
|
||||
"""
|
||||
@@ -159,20 +158,20 @@ def decode_file(
|
||||
embed_mode,
|
||||
channel_key,
|
||||
)
|
||||
|
||||
|
||||
if not result.is_file:
|
||||
raise DecryptionError("Payload is a text message, not a file")
|
||||
|
||||
|
||||
if output_path is None:
|
||||
output_path = Path(result.filename or "extracted_file")
|
||||
else:
|
||||
output_path = Path(output_path)
|
||||
if output_path.is_dir():
|
||||
output_path = output_path / (result.filename or "extracted_file")
|
||||
|
||||
|
||||
# Write file
|
||||
output_path.write_bytes(result.file_data or b"")
|
||||
|
||||
|
||||
debug.print(f"File saved to: {output_path}")
|
||||
return output_path
|
||||
|
||||
@@ -182,16 +181,16 @@ def decode_text(
|
||||
reference_photo: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
rsa_password: Optional[str] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
rsa_password: str | None = None,
|
||||
embed_mode: str = EMBED_MODE_AUTO,
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Decode a text message from a stego image.
|
||||
|
||||
|
||||
Convenience function that returns just the message string.
|
||||
|
||||
|
||||
Args:
|
||||
stego_image: Stego image bytes
|
||||
reference_photo: Shared reference photo bytes
|
||||
@@ -201,10 +200,10 @@ def decode_text(
|
||||
rsa_password: Optional RSA key password
|
||||
embed_mode: 'auto', 'lsb', or 'dct'
|
||||
channel_key: Channel key parameter (see decode())
|
||||
|
||||
|
||||
Returns:
|
||||
Decoded message string
|
||||
|
||||
|
||||
Raises:
|
||||
DecryptionError: If payload is a file, not text
|
||||
"""
|
||||
@@ -218,7 +217,7 @@ def decode_text(
|
||||
embed_mode,
|
||||
channel_key,
|
||||
)
|
||||
|
||||
|
||||
if result.is_file:
|
||||
# Try to decode as text
|
||||
if result.file_data:
|
||||
@@ -229,5 +228,5 @@ def decode_text(
|
||||
f"Payload is a binary file ({result.filename or 'unnamed'}), not text"
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
return result.message or ""
|
||||
|
||||
@@ -7,41 +7,40 @@ Changes in v4.0.0:
|
||||
- Added channel_key parameter for deployment/group isolation
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
from pathlib import Path
|
||||
|
||||
from .models import EncodeInput, EncodeResult, FilePayload
|
||||
from .crypto import encrypt_message, derive_pixel_key
|
||||
from .constants import EMBED_MODE_LSB
|
||||
from .crypto import derive_pixel_key, encrypt_message
|
||||
from .debug import debug
|
||||
from .models import EncodeResult, FilePayload
|
||||
from .steganography import embed_in_image
|
||||
from .utils import generate_filename
|
||||
from .validation import (
|
||||
require_valid_payload,
|
||||
require_valid_image,
|
||||
require_security_factors,
|
||||
require_valid_image,
|
||||
require_valid_payload,
|
||||
require_valid_pin,
|
||||
require_valid_rsa_key,
|
||||
)
|
||||
from .utils import generate_filename
|
||||
from .constants import EMBED_MODE_LSB
|
||||
from .debug import debug
|
||||
|
||||
|
||||
def encode(
|
||||
message: Union[str, bytes, FilePayload],
|
||||
message: str | bytes | FilePayload,
|
||||
reference_photo: bytes,
|
||||
carrier_image: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
rsa_password: Optional[str] = None,
|
||||
output_format: Optional[str] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
rsa_password: str | None = None,
|
||||
output_format: str | None = None,
|
||||
embed_mode: str = EMBED_MODE_LSB,
|
||||
dct_output_format: str = "png",
|
||||
dct_color_mode: str = "grayscale",
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> EncodeResult:
|
||||
"""
|
||||
Encode a message or file into an image.
|
||||
|
||||
|
||||
Args:
|
||||
message: Text message, raw bytes, or FilePayload to hide
|
||||
reference_photo: Shared reference photo bytes
|
||||
@@ -58,10 +57,10 @@ def encode(
|
||||
- None or "auto": Use server's configured key
|
||||
- str: Use this specific channel key
|
||||
- "" or False: No channel key (public mode)
|
||||
|
||||
|
||||
Returns:
|
||||
EncodeResult with stego image and metadata
|
||||
|
||||
|
||||
Example:
|
||||
>>> result = encode(
|
||||
... message="Secret message",
|
||||
@@ -72,7 +71,7 @@ def encode(
|
||||
... )
|
||||
>>> with open('stego.png', 'wb') as f:
|
||||
... f.write(result.stego_image)
|
||||
|
||||
|
||||
Example with explicit channel key:
|
||||
>>> result = encode(
|
||||
... message="Secret message",
|
||||
@@ -86,30 +85,30 @@ def encode(
|
||||
debug.print(f"encode: passphrase length={len(passphrase.split())} words, "
|
||||
f"pin={'set' if pin else 'none'}, mode={embed_mode}, "
|
||||
f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}")
|
||||
|
||||
|
||||
# Validate inputs
|
||||
require_valid_payload(message)
|
||||
require_valid_image(reference_photo, "Reference photo")
|
||||
require_valid_image(carrier_image, "Carrier image")
|
||||
require_security_factors(pin, rsa_key_data)
|
||||
|
||||
|
||||
if pin:
|
||||
require_valid_pin(pin)
|
||||
if rsa_key_data:
|
||||
require_valid_rsa_key(rsa_key_data, rsa_password)
|
||||
|
||||
|
||||
# Encrypt message (with channel key)
|
||||
encrypted = encrypt_message(
|
||||
message, reference_photo, passphrase, pin, rsa_key_data, channel_key
|
||||
)
|
||||
|
||||
|
||||
debug.print(f"Encrypted payload: {len(encrypted)} bytes")
|
||||
|
||||
|
||||
# Derive pixel/coefficient selection key (with channel key)
|
||||
pixel_key = derive_pixel_key(
|
||||
reference_photo, passphrase, pin, rsa_key_data, channel_key
|
||||
)
|
||||
|
||||
|
||||
# Embed in image
|
||||
stego_data, stats, extension = embed_in_image(
|
||||
encrypted,
|
||||
@@ -120,10 +119,10 @@ def encode(
|
||||
dct_output_format=dct_output_format,
|
||||
dct_color_mode=dct_color_mode,
|
||||
)
|
||||
|
||||
|
||||
# Generate filename
|
||||
filename = generate_filename(extension=extension)
|
||||
|
||||
|
||||
# Create result
|
||||
if hasattr(stats, 'pixels_modified'):
|
||||
# LSB mode stats
|
||||
@@ -148,25 +147,25 @@ def encode(
|
||||
|
||||
|
||||
def encode_file(
|
||||
filepath: Union[str, Path],
|
||||
filepath: str | Path,
|
||||
reference_photo: bytes,
|
||||
carrier_image: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
rsa_password: Optional[str] = None,
|
||||
output_format: Optional[str] = None,
|
||||
filename_override: Optional[str] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
rsa_password: str | None = None,
|
||||
output_format: str | None = None,
|
||||
filename_override: str | None = None,
|
||||
embed_mode: str = EMBED_MODE_LSB,
|
||||
dct_output_format: str = "png",
|
||||
dct_color_mode: str = "grayscale",
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> EncodeResult:
|
||||
"""
|
||||
Encode a file into an image.
|
||||
|
||||
|
||||
Convenience wrapper that loads a file and encodes it.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Path to file to embed
|
||||
reference_photo: Shared reference photo bytes
|
||||
@@ -181,12 +180,12 @@ def encode_file(
|
||||
dct_output_format: 'png' or 'jpeg'
|
||||
dct_color_mode: 'grayscale' or 'color'
|
||||
channel_key: Channel key parameter (see encode())
|
||||
|
||||
|
||||
Returns:
|
||||
EncodeResult
|
||||
"""
|
||||
payload = FilePayload.from_file(str(filepath), filename_override)
|
||||
|
||||
|
||||
return encode(
|
||||
message=payload,
|
||||
reference_photo=reference_photo,
|
||||
@@ -210,18 +209,18 @@ def encode_bytes(
|
||||
carrier_image: bytes,
|
||||
passphrase: str,
|
||||
pin: str = "",
|
||||
rsa_key_data: Optional[bytes] = None,
|
||||
rsa_password: Optional[str] = None,
|
||||
output_format: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
rsa_key_data: bytes | None = None,
|
||||
rsa_password: str | None = None,
|
||||
output_format: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
embed_mode: str = EMBED_MODE_LSB,
|
||||
dct_output_format: str = "png",
|
||||
dct_color_mode: str = "grayscale",
|
||||
channel_key: Optional[Union[str, bool]] = None,
|
||||
channel_key: str | bool | None = None,
|
||||
) -> EncodeResult:
|
||||
"""
|
||||
Encode raw bytes with metadata into an image.
|
||||
|
||||
|
||||
Args:
|
||||
data: Raw bytes to embed
|
||||
filename: Filename to associate with data
|
||||
@@ -237,12 +236,12 @@ def encode_bytes(
|
||||
dct_output_format: 'png' or 'jpeg'
|
||||
dct_color_mode: 'grayscale' or 'color'
|
||||
channel_key: Channel key parameter (see encode())
|
||||
|
||||
|
||||
Returns:
|
||||
EncodeResult
|
||||
"""
|
||||
payload = FilePayload(data=data, filename=filename, mime_type=mime_type)
|
||||
|
||||
|
||||
return encode(
|
||||
message=payload,
|
||||
reference_photo=reference_photo,
|
||||
|
||||
@@ -89,7 +89,7 @@ class SteganographyError(StegasooError):
|
||||
|
||||
class CapacityError(SteganographyError):
|
||||
"""Carrier image too small for message."""
|
||||
|
||||
|
||||
def __init__(self, needed: int, available: int):
|
||||
self.needed = needed
|
||||
self.available = available
|
||||
@@ -129,7 +129,7 @@ class FileNotFoundError(FileError):
|
||||
|
||||
class FileTooLargeError(FileError):
|
||||
"""File exceeds size limit."""
|
||||
|
||||
|
||||
def __init__(self, size: int, limit: int, filename: str = "File"):
|
||||
self.size = size
|
||||
self.limit = limit
|
||||
@@ -141,7 +141,7 @@ class FileTooLargeError(FileError):
|
||||
|
||||
class UnsupportedFileTypeError(FileError):
|
||||
"""File type not supported."""
|
||||
|
||||
|
||||
def __init__(self, extension: str, allowed: set[str]):
|
||||
self.extension = extension
|
||||
self.allowed = allowed
|
||||
|
||||
@@ -4,28 +4,30 @@ Stegasoo Generate Module (v3.2.0)
|
||||
Public API for generating credentials (PINs, passphrases, RSA keys).
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .keygen import (
|
||||
generate_pin as _generate_pin,
|
||||
generate_phrase,
|
||||
generate_rsa_key as _generate_rsa_key,
|
||||
export_rsa_key_pem,
|
||||
load_rsa_key,
|
||||
)
|
||||
from .models import Credentials
|
||||
from .constants import (
|
||||
DEFAULT_PIN_LENGTH,
|
||||
DEFAULT_PASSPHRASE_WORDS,
|
||||
DEFAULT_PIN_LENGTH,
|
||||
DEFAULT_RSA_BITS,
|
||||
)
|
||||
from .debug import debug
|
||||
|
||||
from .keygen import (
|
||||
export_rsa_key_pem,
|
||||
generate_phrase,
|
||||
load_rsa_key,
|
||||
)
|
||||
from .keygen import (
|
||||
generate_pin as _generate_pin,
|
||||
)
|
||||
from .keygen import (
|
||||
generate_rsa_key as _generate_rsa_key,
|
||||
)
|
||||
from .models import Credentials
|
||||
|
||||
# Re-export from keygen for convenience
|
||||
__all__ = [
|
||||
'generate_pin',
|
||||
'generate_passphrase',
|
||||
'generate_passphrase',
|
||||
'generate_rsa_key',
|
||||
'generate_credentials',
|
||||
'export_rsa_key_pem',
|
||||
@@ -36,15 +38,15 @@ __all__ = [
|
||||
def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
|
||||
"""
|
||||
Generate a random PIN.
|
||||
|
||||
|
||||
PINs never start with zero for usability.
|
||||
|
||||
|
||||
Args:
|
||||
length: PIN length (6-9 digits, default 6)
|
||||
|
||||
|
||||
Returns:
|
||||
PIN string
|
||||
|
||||
|
||||
Example:
|
||||
>>> pin = generate_pin()
|
||||
>>> len(pin)
|
||||
@@ -58,16 +60,16 @@ def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
|
||||
def generate_passphrase(words: int = DEFAULT_PASSPHRASE_WORDS) -> str:
|
||||
"""
|
||||
Generate a random passphrase from BIP-39 wordlist.
|
||||
|
||||
|
||||
In v3.2.0, this generates a single passphrase (not daily phrases).
|
||||
Default is 4 words for good security (increased from 3 in v3.1.0).
|
||||
|
||||
|
||||
Args:
|
||||
words: Number of words (3-12, default 4)
|
||||
|
||||
|
||||
Returns:
|
||||
Space-separated passphrase
|
||||
|
||||
|
||||
Example:
|
||||
>>> passphrase = generate_passphrase(4)
|
||||
>>> len(passphrase.split())
|
||||
@@ -78,18 +80,18 @@ def generate_passphrase(words: int = DEFAULT_PASSPHRASE_WORDS) -> str:
|
||||
|
||||
def generate_rsa_key(
|
||||
bits: int = DEFAULT_RSA_BITS,
|
||||
password: Optional[str] = None
|
||||
password: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate an RSA private key in PEM format.
|
||||
|
||||
|
||||
Args:
|
||||
bits: Key size (2048, 3072, or 4096, default 2048)
|
||||
password: Optional password to encrypt the key
|
||||
|
||||
|
||||
Returns:
|
||||
PEM-encoded key string
|
||||
|
||||
|
||||
Example:
|
||||
>>> key_pem = generate_rsa_key(2048)
|
||||
>>> '-----BEGIN PRIVATE KEY-----' in key_pem
|
||||
@@ -106,14 +108,14 @@ def generate_credentials(
|
||||
pin_length: int = DEFAULT_PIN_LENGTH,
|
||||
rsa_bits: int = DEFAULT_RSA_BITS,
|
||||
passphrase_words: int = DEFAULT_PASSPHRASE_WORDS,
|
||||
rsa_password: Optional[str] = None,
|
||||
rsa_password: str | None = None,
|
||||
) -> Credentials:
|
||||
"""
|
||||
Generate a complete set of credentials.
|
||||
|
||||
|
||||
In v3.2.0, this generates a single passphrase (not daily phrases).
|
||||
At least one of use_pin or use_rsa must be True.
|
||||
|
||||
|
||||
Args:
|
||||
use_pin: Whether to generate a PIN
|
||||
use_rsa: Whether to generate an RSA key
|
||||
@@ -121,13 +123,13 @@ def generate_credentials(
|
||||
rsa_bits: RSA key size (default 2048)
|
||||
passphrase_words: Number of words in passphrase (default 4)
|
||||
rsa_password: Optional password for RSA key
|
||||
|
||||
|
||||
Returns:
|
||||
Credentials object with passphrase, PIN, and/or RSA key
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If neither PIN nor RSA is selected
|
||||
|
||||
|
||||
Example:
|
||||
>>> creds = generate_credentials(use_pin=True, use_rsa=False)
|
||||
>>> len(creds.passphrase.split())
|
||||
@@ -137,23 +139,23 @@ def generate_credentials(
|
||||
"""
|
||||
if not use_pin and not use_rsa:
|
||||
raise ValueError("Must select at least one security factor (PIN or RSA key)")
|
||||
|
||||
|
||||
debug.print(f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, "
|
||||
f"passphrase_words={passphrase_words}")
|
||||
|
||||
|
||||
# Generate passphrase (single, not daily)
|
||||
passphrase = generate_phrase(passphrase_words)
|
||||
|
||||
|
||||
# Generate PIN if requested
|
||||
pin = _generate_pin(pin_length) if use_pin else None
|
||||
|
||||
|
||||
# Generate RSA key if requested
|
||||
rsa_key_pem = None
|
||||
if use_rsa:
|
||||
rsa_key_obj = _generate_rsa_key(rsa_bits)
|
||||
rsa_key_bytes = export_rsa_key_pem(rsa_key_obj, rsa_password)
|
||||
rsa_key_pem = rsa_key_bytes.decode('utf-8')
|
||||
|
||||
|
||||
# Create Credentials object (v3.2.0 format)
|
||||
creds = Credentials(
|
||||
passphrase=passphrase,
|
||||
@@ -162,6 +164,6 @@ def generate_credentials(
|
||||
rsa_bits=rsa_bits if use_rsa else None,
|
||||
words_per_passphrase=passphrase_words,
|
||||
)
|
||||
|
||||
|
||||
debug.print(f"Credentials generated: {creds.total_entropy} bits total entropy")
|
||||
return creds
|
||||
|
||||
@@ -4,40 +4,40 @@ Stegasoo Image Utilities (v3.2.0)
|
||||
Functions for analyzing images and comparing capacity.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .models import ImageInfo, CapacityComparison
|
||||
from .steganography import calculate_capacity, has_dct_support
|
||||
from .constants import EMBED_MODE_LSB, EMBED_MODE_DCT
|
||||
from .constants import EMBED_MODE_LSB
|
||||
from .debug import debug
|
||||
from .models import CapacityComparison, ImageInfo
|
||||
from .steganography import calculate_capacity, has_dct_support
|
||||
|
||||
|
||||
def get_image_info(image_data: bytes) -> ImageInfo:
|
||||
"""
|
||||
Get detailed information about an image.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Image file bytes
|
||||
|
||||
|
||||
Returns:
|
||||
ImageInfo with dimensions, format, capacity estimates
|
||||
|
||||
|
||||
Example:
|
||||
>>> info = get_image_info(carrier_bytes)
|
||||
>>> print(f"{info.width}x{info.height}, {info.lsb_capacity_kb} KB capacity")
|
||||
"""
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
|
||||
width, height = img.size
|
||||
pixels = width * height
|
||||
format_str = img.format or "Unknown"
|
||||
mode = img.mode
|
||||
|
||||
|
||||
# Calculate LSB capacity
|
||||
lsb_capacity = calculate_capacity(image_data, bits_per_channel=1)
|
||||
|
||||
|
||||
# Calculate DCT capacity if available
|
||||
dct_capacity = None
|
||||
if has_dct_support():
|
||||
@@ -47,7 +47,7 @@ def get_image_info(image_data: bytes) -> ImageInfo:
|
||||
dct_capacity = dct_info.usable_capacity_bytes
|
||||
except Exception as e:
|
||||
debug.print(f"Could not calculate DCT capacity: {e}")
|
||||
|
||||
|
||||
info = ImageInfo(
|
||||
width=width,
|
||||
height=height,
|
||||
@@ -60,27 +60,27 @@ def get_image_info(image_data: bytes) -> ImageInfo:
|
||||
dct_capacity_bytes=dct_capacity,
|
||||
dct_capacity_kb=dct_capacity / 1024 if dct_capacity else None,
|
||||
)
|
||||
|
||||
|
||||
debug.print(f"Image info: {width}x{height}, LSB={lsb_capacity} bytes, "
|
||||
f"DCT={dct_capacity or 'N/A'} bytes")
|
||||
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def compare_capacity(
|
||||
carrier_image: bytes,
|
||||
reference_photo: Optional[bytes] = None,
|
||||
reference_photo: bytes | None = None,
|
||||
) -> CapacityComparison:
|
||||
"""
|
||||
Compare embedding capacity between LSB and DCT modes.
|
||||
|
||||
|
||||
Args:
|
||||
carrier_image: Carrier image bytes
|
||||
reference_photo: Optional reference photo (not used in v3.2.0, kept for API compatibility)
|
||||
|
||||
|
||||
Returns:
|
||||
CapacityComparison with capacity info for both modes
|
||||
|
||||
|
||||
Example:
|
||||
>>> comparison = compare_capacity(carrier_bytes)
|
||||
>>> print(f"LSB: {comparison.lsb_kb:.1f} KB")
|
||||
@@ -88,16 +88,16 @@ def compare_capacity(
|
||||
"""
|
||||
img = Image.open(io.BytesIO(carrier_image))
|
||||
width, height = img.size
|
||||
|
||||
|
||||
# LSB capacity
|
||||
lsb_bytes = calculate_capacity(carrier_image, bits_per_channel=1)
|
||||
lsb_kb = lsb_bytes / 1024
|
||||
|
||||
|
||||
# DCT capacity
|
||||
dct_available = has_dct_support()
|
||||
dct_bytes = None
|
||||
dct_kb = None
|
||||
|
||||
|
||||
if dct_available:
|
||||
try:
|
||||
from .dct_steganography import calculate_dct_capacity
|
||||
@@ -107,7 +107,7 @@ def compare_capacity(
|
||||
except Exception as e:
|
||||
debug.print(f"DCT capacity calculation failed: {e}")
|
||||
dct_available = False
|
||||
|
||||
|
||||
comparison = CapacityComparison(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
@@ -121,9 +121,9 @@ def compare_capacity(
|
||||
dct_output_formats=["PNG (grayscale)", "JPEG (grayscale)"] if dct_available else None,
|
||||
dct_ratio_vs_lsb=(dct_bytes / lsb_bytes * 100) if dct_bytes else None,
|
||||
)
|
||||
|
||||
|
||||
debug.print(f"Capacity comparison: LSB={lsb_kb:.1f}KB, DCT={dct_kb or 'N/A'}KB")
|
||||
|
||||
|
||||
return comparison
|
||||
|
||||
|
||||
@@ -134,27 +134,27 @@ def validate_carrier_capacity(
|
||||
) -> dict:
|
||||
"""
|
||||
Check if a payload will fit in a carrier image.
|
||||
|
||||
|
||||
Args:
|
||||
carrier_image: Carrier image bytes
|
||||
payload_size: Size of payload in bytes
|
||||
embed_mode: 'lsb' or 'dct'
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with 'fits', 'capacity', 'usage_percent', 'headroom'
|
||||
"""
|
||||
from .steganography import calculate_capacity_by_mode
|
||||
|
||||
|
||||
capacity_info = calculate_capacity_by_mode(carrier_image, embed_mode)
|
||||
capacity = capacity_info['capacity_bytes']
|
||||
|
||||
|
||||
# Add encryption overhead estimate
|
||||
estimated_size = payload_size + 200 # Approximate overhead
|
||||
|
||||
|
||||
fits = estimated_size <= capacity
|
||||
usage_percent = (estimated_size / capacity * 100) if capacity > 0 else 100.0
|
||||
headroom = capacity - estimated_size
|
||||
|
||||
|
||||
return {
|
||||
'fits': fits,
|
||||
'capacity': capacity,
|
||||
|
||||
@@ -10,53 +10,57 @@ Changes in v3.2.0:
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from typing import Optional, Dict, Union
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from .constants import (
|
||||
DAY_NAMES,
|
||||
MIN_PIN_LENGTH, MAX_PIN_LENGTH, DEFAULT_PIN_LENGTH,
|
||||
MIN_PASSPHRASE_WORDS, MAX_PASSPHRASE_WORDS, DEFAULT_PASSPHRASE_WORDS,
|
||||
MIN_RSA_BITS, VALID_RSA_SIZES, DEFAULT_RSA_BITS,
|
||||
DEFAULT_PASSPHRASE_WORDS,
|
||||
DEFAULT_PIN_LENGTH,
|
||||
DEFAULT_RSA_BITS,
|
||||
MAX_PASSPHRASE_WORDS,
|
||||
MAX_PIN_LENGTH,
|
||||
MIN_PASSPHRASE_WORDS,
|
||||
MIN_PIN_LENGTH,
|
||||
VALID_RSA_SIZES,
|
||||
get_wordlist,
|
||||
)
|
||||
from .models import Credentials, KeyInfo
|
||||
from .exceptions import KeyGenerationError, KeyPasswordError
|
||||
from .debug import debug
|
||||
from .exceptions import KeyGenerationError, KeyPasswordError
|
||||
from .models import Credentials, KeyInfo
|
||||
|
||||
|
||||
def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
|
||||
"""
|
||||
Generate a random PIN.
|
||||
|
||||
|
||||
PINs never start with zero for usability.
|
||||
|
||||
|
||||
Args:
|
||||
length: PIN length (6-9 digits)
|
||||
|
||||
|
||||
Returns:
|
||||
PIN string
|
||||
|
||||
|
||||
Example:
|
||||
>>> generate_pin(6)
|
||||
"812345"
|
||||
"""
|
||||
debug.validate(MIN_PIN_LENGTH <= length <= MAX_PIN_LENGTH,
|
||||
f"PIN length must be between {MIN_PIN_LENGTH} and {MAX_PIN_LENGTH}")
|
||||
|
||||
|
||||
length = max(MIN_PIN_LENGTH, min(MAX_PIN_LENGTH, length))
|
||||
|
||||
|
||||
# First digit: 1-9 (no leading zero)
|
||||
first_digit = str(secrets.randbelow(9) + 1)
|
||||
|
||||
|
||||
# Remaining digits: 0-9
|
||||
rest = ''.join(str(secrets.randbelow(10)) for _ in range(length - 1))
|
||||
|
||||
|
||||
pin = first_digit + rest
|
||||
debug.print(f"Generated PIN: {pin}")
|
||||
return pin
|
||||
@@ -65,23 +69,23 @@ def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
|
||||
def generate_phrase(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> str:
|
||||
"""
|
||||
Generate a random passphrase from BIP-39 wordlist.
|
||||
|
||||
|
||||
Args:
|
||||
words_per_phrase: Number of words (3-12)
|
||||
|
||||
|
||||
Returns:
|
||||
Space-separated phrase
|
||||
|
||||
|
||||
Example:
|
||||
>>> generate_phrase(4)
|
||||
"apple forest thunder mountain"
|
||||
"""
|
||||
debug.validate(MIN_PASSPHRASE_WORDS <= words_per_phrase <= MAX_PASSPHRASE_WORDS,
|
||||
f"Words per phrase must be between {MIN_PASSPHRASE_WORDS} and {MAX_PASSPHRASE_WORDS}")
|
||||
|
||||
|
||||
words_per_phrase = max(MIN_PASSPHRASE_WORDS, min(MAX_PASSPHRASE_WORDS, words_per_phrase))
|
||||
wordlist = get_wordlist()
|
||||
|
||||
|
||||
words = [secrets.choice(wordlist) for _ in range(words_per_phrase)]
|
||||
phrase = ' '.join(words)
|
||||
debug.print(f"Generated phrase: {phrase}")
|
||||
@@ -92,19 +96,19 @@ def generate_phrase(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> str:
|
||||
generate_passphrase = generate_phrase
|
||||
|
||||
|
||||
def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> Dict[str, str]:
|
||||
def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> dict[str, str]:
|
||||
"""
|
||||
Generate phrases for all days of the week.
|
||||
|
||||
|
||||
DEPRECATED in v3.2.0: Use generate_phrase() for single passphrase.
|
||||
Kept for legacy compatibility and organizational use cases.
|
||||
|
||||
|
||||
Args:
|
||||
words_per_phrase: Number of words per phrase (3-12)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict mapping day names to phrases
|
||||
|
||||
|
||||
Example:
|
||||
>>> generate_day_phrases(3)
|
||||
{'Monday': 'apple forest thunder', 'Tuesday': 'banana river lightning', ...}
|
||||
@@ -116,7 +120,7 @@ def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> Di
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
|
||||
phrases = {day: generate_phrase(words_per_phrase) for day in DAY_NAMES}
|
||||
debug.print(f"Generated phrases for {len(phrases)} days")
|
||||
return phrases
|
||||
@@ -125,16 +129,16 @@ def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> Di
|
||||
def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
|
||||
"""
|
||||
Generate an RSA private key.
|
||||
|
||||
|
||||
Args:
|
||||
bits: Key size (2048, 3072, or 4096)
|
||||
|
||||
|
||||
Returns:
|
||||
RSA private key object
|
||||
|
||||
|
||||
Raises:
|
||||
KeyGenerationError: If generation fails
|
||||
|
||||
|
||||
Example:
|
||||
>>> key = generate_rsa_key(2048)
|
||||
>>> key.key_size
|
||||
@@ -142,10 +146,10 @@ def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
|
||||
"""
|
||||
debug.validate(bits in VALID_RSA_SIZES,
|
||||
f"RSA key size must be one of {VALID_RSA_SIZES}")
|
||||
|
||||
|
||||
if bits not in VALID_RSA_SIZES:
|
||||
bits = DEFAULT_RSA_BITS
|
||||
|
||||
|
||||
debug.print(f"Generating {bits}-bit RSA key...")
|
||||
try:
|
||||
key = rsa.generate_private_key(
|
||||
@@ -162,18 +166,18 @@ def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
|
||||
|
||||
def export_rsa_key_pem(
|
||||
private_key: rsa.RSAPrivateKey,
|
||||
password: Optional[str] = None
|
||||
password: str | None = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Export RSA key to PEM format.
|
||||
|
||||
|
||||
Args:
|
||||
private_key: RSA private key object
|
||||
password: Optional password for encryption
|
||||
|
||||
|
||||
Returns:
|
||||
PEM-encoded key bytes
|
||||
|
||||
|
||||
Example:
|
||||
>>> key = generate_rsa_key()
|
||||
>>> pem = export_rsa_key_pem(key)
|
||||
@@ -181,19 +185,16 @@ def export_rsa_key_pem(
|
||||
b'-----BEGIN PRIVATE KEY-----\\nMIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYw'
|
||||
"""
|
||||
debug.validate(private_key is not None, "Private key cannot be None")
|
||||
|
||||
encryption_algorithm: Union[
|
||||
serialization.BestAvailableEncryption,
|
||||
serialization.NoEncryption
|
||||
]
|
||||
|
||||
|
||||
encryption_algorithm: serialization.BestAvailableEncryption | serialization.NoEncryption
|
||||
|
||||
if password:
|
||||
encryption_algorithm = serialization.BestAvailableEncryption(password.encode())
|
||||
debug.print("Exporting RSA key with encryption")
|
||||
else:
|
||||
encryption_algorithm = serialization.NoEncryption()
|
||||
debug.print("Exporting RSA key without encryption")
|
||||
|
||||
|
||||
return private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
@@ -203,39 +204,39 @@ def export_rsa_key_pem(
|
||||
|
||||
def load_rsa_key(
|
||||
key_data: bytes,
|
||||
password: Optional[str] = None
|
||||
password: str | None = None
|
||||
) -> rsa.RSAPrivateKey:
|
||||
"""
|
||||
Load RSA private key from PEM data.
|
||||
|
||||
|
||||
Args:
|
||||
key_data: PEM-encoded key bytes
|
||||
password: Password if key is encrypted
|
||||
|
||||
|
||||
Returns:
|
||||
RSA private key object
|
||||
|
||||
|
||||
Raises:
|
||||
KeyPasswordError: If password is wrong or missing
|
||||
KeyGenerationError: If key is invalid
|
||||
|
||||
|
||||
Example:
|
||||
>>> key = load_rsa_key(pem_data, "my_password")
|
||||
"""
|
||||
debug.validate(key_data is not None and len(key_data) > 0,
|
||||
"Key data cannot be empty")
|
||||
|
||||
|
||||
try:
|
||||
pwd_bytes = password.encode() if password else None
|
||||
debug.print(f"Loading RSA key (encrypted: {bool(password)})")
|
||||
key: PrivateKeyTypes = load_pem_private_key(
|
||||
key_data, password=pwd_bytes, backend=default_backend()
|
||||
)
|
||||
|
||||
|
||||
# Verify it's an RSA key
|
||||
if not isinstance(key, rsa.RSAPrivateKey):
|
||||
raise KeyGenerationError(f"Expected RSA key, got {type(key).__name__}")
|
||||
|
||||
|
||||
debug.print(f"RSA key loaded: {key.key_size} bits")
|
||||
return key
|
||||
except TypeError:
|
||||
@@ -253,17 +254,17 @@ def load_rsa_key(
|
||||
raise KeyGenerationError(f"Could not load RSA key: {e}") from e
|
||||
|
||||
|
||||
def get_key_info(key_data: bytes, password: Optional[str] = None) -> KeyInfo:
|
||||
def get_key_info(key_data: bytes, password: str | None = None) -> KeyInfo:
|
||||
"""
|
||||
Get information about an RSA key.
|
||||
|
||||
|
||||
Args:
|
||||
key_data: PEM-encoded key bytes
|
||||
password: Password if key is encrypted
|
||||
|
||||
|
||||
Returns:
|
||||
KeyInfo with key size and encryption status
|
||||
|
||||
|
||||
Example:
|
||||
>>> info = get_key_info(pem_data)
|
||||
>>> info.key_size
|
||||
@@ -274,15 +275,15 @@ def get_key_info(key_data: bytes, password: Optional[str] = None) -> KeyInfo:
|
||||
debug.print("Getting RSA key info")
|
||||
# Check if encrypted
|
||||
is_encrypted = b'ENCRYPTED' in key_data
|
||||
|
||||
|
||||
private_key = load_rsa_key(key_data, password)
|
||||
|
||||
|
||||
info = KeyInfo(
|
||||
key_size=private_key.key_size,
|
||||
is_encrypted=is_encrypted,
|
||||
pem_data=key_data
|
||||
)
|
||||
|
||||
|
||||
debug.print(f"Key info: {info.key_size} bits, encrypted: {info.is_encrypted}")
|
||||
return info
|
||||
|
||||
@@ -293,14 +294,14 @@ def generate_credentials(
|
||||
pin_length: int = DEFAULT_PIN_LENGTH,
|
||||
rsa_bits: int = DEFAULT_RSA_BITS,
|
||||
passphrase_words: int = DEFAULT_PASSPHRASE_WORDS,
|
||||
rsa_password: Optional[str] = None,
|
||||
rsa_password: str | None = None,
|
||||
) -> Credentials:
|
||||
"""
|
||||
Generate a complete set of credentials.
|
||||
|
||||
|
||||
v3.2.0: Now generates a single passphrase instead of daily phrases.
|
||||
At least one of use_pin or use_rsa must be True.
|
||||
|
||||
|
||||
Args:
|
||||
use_pin: Whether to generate a PIN
|
||||
use_rsa: Whether to generate an RSA key
|
||||
@@ -308,13 +309,13 @@ def generate_credentials(
|
||||
rsa_bits: RSA key size if generating (default 2048)
|
||||
passphrase_words: Words in passphrase (default 4)
|
||||
rsa_password: Optional password for RSA key encryption
|
||||
|
||||
|
||||
Returns:
|
||||
Credentials object with passphrase, PIN, and/or RSA key
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If neither PIN nor RSA is selected
|
||||
|
||||
|
||||
Example:
|
||||
>>> creds = generate_credentials(use_pin=True, use_rsa=False)
|
||||
>>> creds.passphrase
|
||||
@@ -324,25 +325,25 @@ def generate_credentials(
|
||||
"""
|
||||
debug.validate(use_pin or use_rsa,
|
||||
"Must select at least one security factor (PIN or RSA key)")
|
||||
|
||||
|
||||
if not use_pin and not use_rsa:
|
||||
raise ValueError("Must select at least one security factor (PIN or RSA key)")
|
||||
|
||||
|
||||
debug.print(f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, "
|
||||
f"passphrase_words={passphrase_words}")
|
||||
|
||||
|
||||
# Generate single passphrase (v3.2.0 - no daily rotation)
|
||||
passphrase = generate_phrase(passphrase_words)
|
||||
|
||||
|
||||
# Generate PIN if requested
|
||||
pin = generate_pin(pin_length) if use_pin else None
|
||||
|
||||
|
||||
# Generate RSA key if requested
|
||||
rsa_key_pem = None
|
||||
if use_rsa:
|
||||
rsa_key_obj = generate_rsa_key(rsa_bits)
|
||||
rsa_key_pem = export_rsa_key_pem(rsa_key_obj, rsa_password).decode('utf-8')
|
||||
|
||||
|
||||
# Create Credentials object (v3.2.0 format with single passphrase)
|
||||
creds = Credentials(
|
||||
passphrase=passphrase,
|
||||
@@ -351,7 +352,7 @@ def generate_credentials(
|
||||
rsa_bits=rsa_bits if use_rsa else None,
|
||||
words_per_passphrase=passphrase_words,
|
||||
)
|
||||
|
||||
|
||||
debug.print(f"Credentials generated: {creds.total_entropy} bits total entropy")
|
||||
return creds
|
||||
|
||||
@@ -369,19 +370,19 @@ def generate_credentials_legacy(
|
||||
) -> dict:
|
||||
"""
|
||||
Generate credentials in legacy format (v3.1.0 style with daily phrases).
|
||||
|
||||
|
||||
DEPRECATED: Use generate_credentials() for v3.2.0 format.
|
||||
|
||||
|
||||
This function exists only for migration tools that need to work with
|
||||
old-format credentials.
|
||||
|
||||
|
||||
Args:
|
||||
use_pin: Whether to generate a PIN
|
||||
use_rsa: Whether to generate an RSA key
|
||||
pin_length: PIN length if generating
|
||||
rsa_bits: RSA key size if generating
|
||||
words_per_phrase: Words per daily phrase
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with 'phrases' (dict), 'pin', 'rsa_key_pem', etc.
|
||||
"""
|
||||
@@ -392,20 +393,20 @@ def generate_credentials_legacy(
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
|
||||
if not use_pin and not use_rsa:
|
||||
raise ValueError("Must select at least one security factor (PIN or RSA key)")
|
||||
|
||||
|
||||
# Generate daily phrases (old format)
|
||||
phrases = {day: generate_phrase(words_per_phrase) for day in DAY_NAMES}
|
||||
|
||||
|
||||
pin = generate_pin(pin_length) if use_pin else None
|
||||
|
||||
|
||||
rsa_key_pem = None
|
||||
if use_rsa:
|
||||
rsa_key_obj = generate_rsa_key(rsa_bits)
|
||||
rsa_key_pem = export_rsa_key_pem(rsa_key_obj).decode('utf-8')
|
||||
|
||||
|
||||
return {
|
||||
'phrases': phrases,
|
||||
'pin': pin,
|
||||
|
||||
@@ -12,50 +12,48 @@ Changes in v3.2.0:
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date
|
||||
from typing import Optional, Union, List
|
||||
|
||||
|
||||
@dataclass
|
||||
class Credentials:
|
||||
"""
|
||||
Generated credentials for encoding/decoding.
|
||||
|
||||
|
||||
v3.2.0: Simplified to use single passphrase instead of daily rotation.
|
||||
"""
|
||||
passphrase: str # Single passphrase (no daily rotation)
|
||||
pin: Optional[str] = None
|
||||
rsa_key_pem: Optional[str] = None
|
||||
rsa_bits: Optional[int] = None
|
||||
pin: str | None = None
|
||||
rsa_key_pem: str | None = None
|
||||
rsa_bits: int | None = None
|
||||
words_per_passphrase: int = 4 # Increased from 3 in v3.1.0
|
||||
|
||||
|
||||
# Optional: backup passphrases for multi-factor or rotation
|
||||
backup_passphrases: Optional[list[str]] = None
|
||||
|
||||
backup_passphrases: list[str] | None = None
|
||||
|
||||
@property
|
||||
def passphrase_entropy(self) -> int:
|
||||
"""Entropy in bits from passphrase (~11 bits per BIP-39 word)."""
|
||||
return self.words_per_passphrase * 11
|
||||
|
||||
|
||||
@property
|
||||
def pin_entropy(self) -> int:
|
||||
"""Entropy in bits from PIN (~3.32 bits per digit)."""
|
||||
if self.pin:
|
||||
return int(len(self.pin) * 3.32)
|
||||
return 0
|
||||
|
||||
|
||||
@property
|
||||
def rsa_entropy(self) -> int:
|
||||
"""Effective entropy from RSA key."""
|
||||
if self.rsa_key_pem and self.rsa_bits:
|
||||
return min(self.rsa_bits // 16, 128)
|
||||
return 0
|
||||
|
||||
|
||||
@property
|
||||
def total_entropy(self) -> int:
|
||||
"""Total entropy in bits (excluding reference photo)."""
|
||||
return self.passphrase_entropy + self.pin_entropy + self.rsa_entropy
|
||||
|
||||
|
||||
# Legacy property for compatibility
|
||||
@property
|
||||
def phrase_entropy(self) -> int:
|
||||
@@ -68,23 +66,23 @@ class FilePayload:
|
||||
"""Represents a file to be embedded."""
|
||||
data: bytes
|
||||
filename: str
|
||||
mime_type: Optional[str] = None
|
||||
|
||||
mime_type: str | None = None
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, filepath: str, filename: Optional[str] = None) -> 'FilePayload':
|
||||
def from_file(cls, filepath: str, filename: str | None = None) -> 'FilePayload':
|
||||
"""Create FilePayload from a file path."""
|
||||
from pathlib import Path
|
||||
import mimetypes
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
path = Path(filepath)
|
||||
data = path.read_bytes()
|
||||
name = filename or path.name
|
||||
mime, _ = mimetypes.guess_type(name)
|
||||
|
||||
|
||||
return cls(data=data, filename=name, mime_type=mime)
|
||||
|
||||
|
||||
@@ -92,23 +90,23 @@ class FilePayload:
|
||||
class EncodeInput:
|
||||
"""
|
||||
Input parameters for encoding a message.
|
||||
|
||||
|
||||
v3.2.0: Removed date_str (date no longer used in crypto).
|
||||
"""
|
||||
message: Union[str, bytes, FilePayload] # Text, raw bytes, or file
|
||||
message: str | bytes | FilePayload # Text, raw bytes, or file
|
||||
reference_photo: bytes
|
||||
carrier_image: bytes
|
||||
passphrase: str # Renamed from day_phrase
|
||||
pin: str = ""
|
||||
rsa_key_data: Optional[bytes] = None
|
||||
rsa_password: Optional[str] = None
|
||||
rsa_key_data: bytes | None = None
|
||||
rsa_password: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EncodeResult:
|
||||
"""
|
||||
Result of encoding operation.
|
||||
|
||||
|
||||
v3.2.0: date_used is now optional/cosmetic (not used in crypto).
|
||||
"""
|
||||
stego_image: bytes
|
||||
@@ -116,8 +114,8 @@ class EncodeResult:
|
||||
pixels_modified: int
|
||||
total_pixels: int
|
||||
capacity_used: float # 0.0 - 1.0
|
||||
date_used: Optional[str] = None # Cosmetic only (for filename organization)
|
||||
|
||||
date_used: str | None = None # Cosmetic only (for filename organization)
|
||||
|
||||
@property
|
||||
def capacity_percent(self) -> float:
|
||||
"""Capacity used as percentage."""
|
||||
@@ -128,54 +126,54 @@ class EncodeResult:
|
||||
class DecodeInput:
|
||||
"""
|
||||
Input parameters for decoding a message.
|
||||
|
||||
|
||||
v3.2.0: Renamed day_phrase → passphrase, no date needed.
|
||||
"""
|
||||
stego_image: bytes
|
||||
reference_photo: bytes
|
||||
passphrase: str # Renamed from day_phrase
|
||||
pin: str = ""
|
||||
rsa_key_data: Optional[bytes] = None
|
||||
rsa_password: Optional[str] = None
|
||||
rsa_key_data: bytes | None = None
|
||||
rsa_password: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodeResult:
|
||||
"""
|
||||
Result of decoding operation.
|
||||
|
||||
|
||||
v3.2.0: date_encoded is always None (date removed from crypto).
|
||||
"""
|
||||
payload_type: str # 'text' or 'file'
|
||||
message: Optional[str] = None # For text payloads
|
||||
file_data: Optional[bytes] = None # For file payloads
|
||||
filename: Optional[str] = None # Original filename for file payloads
|
||||
mime_type: Optional[str] = None # MIME type hint
|
||||
date_encoded: Optional[str] = None # Always None in v3.2.0 (kept for compatibility)
|
||||
|
||||
message: str | None = None # For text payloads
|
||||
file_data: bytes | None = None # For file payloads
|
||||
filename: str | None = None # Original filename for file payloads
|
||||
mime_type: str | None = None # MIME type hint
|
||||
date_encoded: str | None = None # Always None in v3.2.0 (kept for compatibility)
|
||||
|
||||
@property
|
||||
def is_file(self) -> bool:
|
||||
return self.payload_type == 'file'
|
||||
|
||||
|
||||
@property
|
||||
def is_text(self) -> bool:
|
||||
return self.payload_type == 'text'
|
||||
|
||||
def get_content(self) -> Union[str, bytes]:
|
||||
|
||||
def get_content(self) -> str | bytes:
|
||||
"""Get the decoded content (text or bytes)."""
|
||||
if self.is_text:
|
||||
return self.message or ""
|
||||
return self.file_data or b""
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class EmbedStats:
|
||||
"""Statistics from image embedding."""
|
||||
pixels_modified: int
|
||||
total_pixels: int
|
||||
capacity_used: float
|
||||
bytes_embedded: int
|
||||
|
||||
|
||||
@property
|
||||
def modification_percent(self) -> float:
|
||||
"""Percentage of pixels modified."""
|
||||
@@ -196,16 +194,16 @@ class ValidationResult:
|
||||
is_valid: bool
|
||||
error_message: str = ""
|
||||
details: dict = field(default_factory=dict)
|
||||
warning: Optional[str] = None # v3.2.0: Added for passphrase length warnings
|
||||
|
||||
warning: str | None = None # v3.2.0: Added for passphrase length warnings
|
||||
|
||||
@classmethod
|
||||
def ok(cls, warning: Optional[str] = None, **details) -> 'ValidationResult':
|
||||
def ok(cls, warning: str | None = None, **details) -> 'ValidationResult':
|
||||
"""Create a successful validation result."""
|
||||
result = cls(is_valid=True, details=details)
|
||||
if warning:
|
||||
result.warning = warning
|
||||
return result
|
||||
|
||||
|
||||
@classmethod
|
||||
def error(cls, message: str, **details) -> 'ValidationResult':
|
||||
"""Create a failed validation result."""
|
||||
@@ -227,8 +225,8 @@ class ImageInfo:
|
||||
file_size: int
|
||||
lsb_capacity_bytes: int
|
||||
lsb_capacity_kb: float
|
||||
dct_capacity_bytes: Optional[int] = None
|
||||
dct_capacity_kb: Optional[float] = None
|
||||
dct_capacity_bytes: int | None = None
|
||||
dct_capacity_kb: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -241,24 +239,24 @@ class CapacityComparison:
|
||||
lsb_kb: float
|
||||
lsb_output_format: str
|
||||
dct_available: bool
|
||||
dct_bytes: Optional[int] = None
|
||||
dct_kb: Optional[float] = None
|
||||
dct_output_formats: Optional[List[str]] = None
|
||||
dct_ratio_vs_lsb: Optional[float] = None
|
||||
dct_bytes: int | None = None
|
||||
dct_kb: float | None = None
|
||||
dct_output_formats: list[str] | None = None
|
||||
dct_ratio_vs_lsb: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateResult:
|
||||
"""Result of credential generation."""
|
||||
passphrase: str
|
||||
pin: Optional[str] = None
|
||||
rsa_key_pem: Optional[str] = None
|
||||
pin: str | None = None
|
||||
rsa_key_pem: str | None = None
|
||||
passphrase_words: int = 4
|
||||
passphrase_entropy: int = 0
|
||||
pin_entropy: int = 0
|
||||
rsa_entropy: int = 0
|
||||
total_entropy: int = 0
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
lines = [
|
||||
"Generated Credentials:",
|
||||
|
||||
@@ -10,10 +10,9 @@ IMPROVEMENTS IN THIS VERSION:
|
||||
- Improved error messages
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import zlib
|
||||
import base64
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -27,20 +26,19 @@ except ImportError:
|
||||
|
||||
# QR code reading
|
||||
try:
|
||||
from pyzbar.pyzbar import decode as pyzbar_decode
|
||||
from pyzbar.pyzbar import ZBarSymbol
|
||||
from pyzbar.pyzbar import decode as pyzbar_decode
|
||||
HAS_QRCODE_READ = True
|
||||
except ImportError:
|
||||
HAS_QRCODE_READ = False
|
||||
|
||||
|
||||
from .constants import (
|
||||
QR_MAX_BINARY,
|
||||
QR_CROP_PADDING_PERCENT,
|
||||
QR_CROP_MIN_PADDING_PX,
|
||||
QR_CROP_PADDING_PERCENT,
|
||||
QR_MAX_BINARY,
|
||||
)
|
||||
|
||||
|
||||
# Constants
|
||||
COMPRESSION_PREFIX = "STEGASOO-Z:"
|
||||
|
||||
@@ -48,10 +46,10 @@ COMPRESSION_PREFIX = "STEGASOO-Z:"
|
||||
def compress_data(data: str) -> str:
|
||||
"""
|
||||
Compress string data for QR code storage.
|
||||
|
||||
|
||||
Args:
|
||||
data: String to compress
|
||||
|
||||
|
||||
Returns:
|
||||
Compressed string with STEGASOO-Z: prefix
|
||||
"""
|
||||
@@ -63,19 +61,19 @@ def compress_data(data: str) -> str:
|
||||
def decompress_data(data: str) -> str:
|
||||
"""
|
||||
Decompress data from QR code.
|
||||
|
||||
|
||||
Args:
|
||||
data: Compressed string with STEGASOO-Z: prefix
|
||||
|
||||
|
||||
Returns:
|
||||
Original uncompressed string
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If data is not valid compressed format
|
||||
"""
|
||||
if not data.startswith(COMPRESSION_PREFIX):
|
||||
raise ValueError("Data is not in compressed format")
|
||||
|
||||
|
||||
encoded = data[len(COMPRESSION_PREFIX):]
|
||||
compressed = base64.b64decode(encoded)
|
||||
return zlib.decompress(compressed).decode('utf-8')
|
||||
@@ -84,7 +82,7 @@ def decompress_data(data: str) -> str:
|
||||
def normalize_pem(pem_data: str) -> str:
|
||||
"""
|
||||
Normalize PEM data to ensure proper formatting for cryptography library.
|
||||
|
||||
|
||||
The cryptography library is very particular about PEM formatting.
|
||||
This function handles all common issues from QR code extraction:
|
||||
- Inconsistent line endings (CRLF, LF, CR)
|
||||
@@ -93,24 +91,24 @@ def normalize_pem(pem_data: str) -> str:
|
||||
- Non-ASCII characters
|
||||
- Incorrect base64 padding
|
||||
- Malformed headers/footers
|
||||
|
||||
|
||||
Args:
|
||||
pem_data: Raw PEM string from QR code
|
||||
|
||||
|
||||
Returns:
|
||||
Properly formatted PEM string that cryptography library will accept
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
# Step 1: Normalize ALL line endings to \n
|
||||
pem_data = pem_data.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
|
||||
# Step 2: Remove leading/trailing whitespace
|
||||
pem_data = pem_data.strip()
|
||||
|
||||
|
||||
# Step 3: Remove any non-ASCII characters (QR artifacts)
|
||||
pem_data = ''.join(char for char in pem_data if ord(char) < 128)
|
||||
|
||||
|
||||
# Step 4: Extract header, content, and footer with flexible regex
|
||||
# This handles variations like:
|
||||
# - "PRIVATE KEY" vs "RSA PRIVATE KEY"
|
||||
@@ -118,51 +116,51 @@ def normalize_pem(pem_data: str) -> str:
|
||||
# - Missing spaces
|
||||
pattern = r'(-----BEGIN[^-]*-----)(.*?)(-----END[^-]*-----)'
|
||||
match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE)
|
||||
|
||||
|
||||
if not match:
|
||||
# Fallback: try even more permissive pattern
|
||||
pattern = r'(-+BEGIN[^-]+-+)(.*?)(-+END[^-]+-+)'
|
||||
match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE)
|
||||
|
||||
|
||||
if not match:
|
||||
# Last resort: return original if can't parse
|
||||
return pem_data
|
||||
|
||||
|
||||
header_raw = match.group(1).strip()
|
||||
content_raw = match.group(2)
|
||||
footer_raw = match.group(3).strip()
|
||||
|
||||
|
||||
# Step 5: Normalize header and footer
|
||||
# Standardize spacing and ensure proper format
|
||||
header = re.sub(r'\s+', ' ', header_raw)
|
||||
footer = re.sub(r'\s+', ' ', footer_raw)
|
||||
|
||||
|
||||
# Ensure exactly 5 dashes on each side
|
||||
header = re.sub(r'^-+', '-----', header)
|
||||
header = re.sub(r'-+$', '-----', header)
|
||||
footer = re.sub(r'^-+', '-----', footer)
|
||||
footer = re.sub(r'-+$', '-----', footer)
|
||||
|
||||
|
||||
# Step 6: Clean the base64 content THOROUGHLY
|
||||
# Remove ALL whitespace: spaces, tabs, newlines
|
||||
# Keep only valid base64 characters: A-Z, a-z, 0-9, +, /, =
|
||||
content_clean = ''.join(
|
||||
char for char in content_raw
|
||||
char for char in content_raw
|
||||
if char.isalnum() or char in '+/='
|
||||
)
|
||||
|
||||
|
||||
# Double-check: remove any remaining invalid characters
|
||||
content_clean = re.sub(r'[^A-Za-z0-9+/=]', '', content_clean)
|
||||
|
||||
|
||||
# Step 7: Fix base64 padding
|
||||
# Base64 strings must be divisible by 4
|
||||
remainder = len(content_clean) % 4
|
||||
if remainder:
|
||||
content_clean += '=' * (4 - remainder)
|
||||
|
||||
|
||||
# Step 8: Split into 64-character lines (PEM standard)
|
||||
lines = [content_clean[i:i+64] for i in range(0, len(content_clean), 64)]
|
||||
|
||||
|
||||
# Step 9: Reconstruct with EXACT PEM formatting
|
||||
# Format: header\ncontent_line1\ncontent_line2\n...\nfooter\n
|
||||
return header + '\n' + '\n'.join(lines) + '\n' + footer + '\n'
|
||||
@@ -176,10 +174,10 @@ def is_compressed(data: str) -> bool:
|
||||
def auto_decompress(data: str) -> str:
|
||||
"""
|
||||
Automatically decompress data if compressed, otherwise return as-is.
|
||||
|
||||
|
||||
Args:
|
||||
data: Possibly compressed string
|
||||
|
||||
|
||||
Returns:
|
||||
Decompressed string
|
||||
"""
|
||||
@@ -196,11 +194,11 @@ def get_compressed_size(data: str) -> int:
|
||||
def can_fit_in_qr(data: str, compress: bool = False) -> bool:
|
||||
"""
|
||||
Check if data can fit in a QR code.
|
||||
|
||||
|
||||
Args:
|
||||
data: String data
|
||||
compress: Whether compression will be used
|
||||
|
||||
|
||||
Returns:
|
||||
True if data fits
|
||||
"""
|
||||
@@ -223,39 +221,39 @@ def generate_qr_code(
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate a QR code PNG from string data.
|
||||
|
||||
|
||||
Args:
|
||||
data: String data to encode
|
||||
compress: Whether to compress data first
|
||||
error_correction: QR error correction level (default: auto)
|
||||
|
||||
|
||||
Returns:
|
||||
PNG image bytes
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If qrcode library not available
|
||||
ValueError: If data too large for QR code
|
||||
"""
|
||||
if not HAS_QRCODE_WRITE:
|
||||
raise RuntimeError("qrcode library not installed. Run: pip install qrcode[pil]")
|
||||
|
||||
|
||||
qr_data = data
|
||||
|
||||
|
||||
# Compress if requested
|
||||
if compress:
|
||||
qr_data = compress_data(data)
|
||||
|
||||
|
||||
# Check size
|
||||
if len(qr_data.encode('utf-8')) > QR_MAX_BINARY:
|
||||
raise ValueError(
|
||||
f"Data too large for QR code ({len(qr_data)} bytes). "
|
||||
f"Maximum: {QR_MAX_BINARY} bytes"
|
||||
)
|
||||
|
||||
|
||||
# Use lower error correction for larger data
|
||||
if error_correction is None:
|
||||
error_correction = ERROR_CORRECT_L if len(qr_data) > 1000 else ERROR_CORRECT_M
|
||||
|
||||
|
||||
qr = qrcode.QRCode(
|
||||
version=None,
|
||||
error_correction=error_correction,
|
||||
@@ -264,25 +262,25 @@ def generate_qr_code(
|
||||
)
|
||||
qr.add_data(qr_data)
|
||||
qr.make(fit=True)
|
||||
|
||||
|
||||
img = qr.make_image(fill_color="black", back_color="white")
|
||||
|
||||
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format='PNG')
|
||||
buf.seek(0)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def read_qr_code(image_data: bytes) -> Optional[str]:
|
||||
def read_qr_code(image_data: bytes) -> str | None:
|
||||
"""
|
||||
Read QR code from image data.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Image bytes (PNG, JPG, etc.)
|
||||
|
||||
|
||||
Returns:
|
||||
Decoded string, or None if no QR code found
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If pyzbar library not available
|
||||
"""
|
||||
@@ -291,35 +289,35 @@ def read_qr_code(image_data: bytes) -> Optional[str]:
|
||||
"pyzbar library not installed. Run: pip install pyzbar\n"
|
||||
"Also requires system library: sudo apt-get install libzbar0"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
img: Image.Image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
|
||||
# Convert to RGB if necessary (pyzbar works best with RGB/grayscale)
|
||||
if img.mode not in ('RGB', 'L'):
|
||||
img = img.convert('RGB')
|
||||
|
||||
|
||||
# Decode QR codes
|
||||
decoded = pyzbar_decode(img, symbols=[ZBarSymbol.QRCODE])
|
||||
|
||||
|
||||
if not decoded:
|
||||
return None
|
||||
|
||||
|
||||
# Return first QR code found
|
||||
result: str = decoded[0].data.decode('utf-8')
|
||||
return result
|
||||
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def read_qr_code_from_file(filepath: str) -> Optional[str]:
|
||||
def read_qr_code_from_file(filepath: str) -> str | None:
|
||||
"""
|
||||
Read QR code from image file.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Path to image file
|
||||
|
||||
|
||||
Returns:
|
||||
Decoded string, or None if no QR code found
|
||||
"""
|
||||
@@ -327,25 +325,25 @@ def read_qr_code_from_file(filepath: str) -> Optional[str]:
|
||||
return read_qr_code(f.read())
|
||||
|
||||
|
||||
def extract_key_from_qr(image_data: bytes) -> Optional[str]:
|
||||
def extract_key_from_qr(image_data: bytes) -> str | None:
|
||||
"""
|
||||
Extract RSA key from QR code image, auto-decompressing if needed.
|
||||
|
||||
|
||||
This function is more robust than the original, with better error handling
|
||||
and PEM normalization.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Image bytes containing QR code
|
||||
|
||||
|
||||
Returns:
|
||||
PEM-encoded RSA key string, or None if not found/invalid
|
||||
"""
|
||||
# Step 1: Read QR code
|
||||
qr_data = read_qr_code(image_data)
|
||||
|
||||
|
||||
if not qr_data:
|
||||
return None
|
||||
|
||||
|
||||
# Step 2: Auto-decompress if needed
|
||||
try:
|
||||
if is_compressed(qr_data):
|
||||
@@ -355,11 +353,11 @@ def extract_key_from_qr(image_data: bytes) -> Optional[str]:
|
||||
except Exception:
|
||||
# If decompression fails, try using data as-is
|
||||
key_pem = qr_data
|
||||
|
||||
|
||||
# Step 3: Validate it looks like a PEM key
|
||||
if '-----BEGIN' not in key_pem or '-----END' not in key_pem:
|
||||
return None
|
||||
|
||||
|
||||
# Step 4: Aggressively normalize PEM format
|
||||
# This is crucial - QR codes can introduce subtle formatting issues
|
||||
try:
|
||||
@@ -367,21 +365,21 @@ def extract_key_from_qr(image_data: bytes) -> Optional[str]:
|
||||
except Exception:
|
||||
# If normalization fails, return None rather than broken PEM
|
||||
return None
|
||||
|
||||
|
||||
# Step 5: Final validation - ensure it still looks like PEM
|
||||
if '-----BEGIN' in key_pem and '-----END' in key_pem:
|
||||
return key_pem
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_key_from_qr_file(filepath: str) -> Optional[str]:
|
||||
def extract_key_from_qr_file(filepath: str) -> str | None:
|
||||
"""
|
||||
Extract RSA key from QR code image file.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Path to image file containing QR code
|
||||
|
||||
|
||||
Returns:
|
||||
PEM-encoded RSA key string, or None if not found/invalid
|
||||
"""
|
||||
@@ -393,21 +391,21 @@ def detect_and_crop_qr(
|
||||
image_data: bytes,
|
||||
padding_percent: float = QR_CROP_PADDING_PERCENT,
|
||||
min_padding_px: int = QR_CROP_MIN_PADDING_PX
|
||||
) -> Optional[bytes]:
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Detect QR code in image and crop to it, handling rotation.
|
||||
|
||||
|
||||
Uses the QR code's corner coordinates to compute an axis-aligned
|
||||
bounding box, then adds padding to ensure rotated QR codes aren't clipped.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Input image bytes (PNG, JPG, etc.)
|
||||
padding_percent: Padding as fraction of QR size (default 10%)
|
||||
min_padding_px: Minimum padding in pixels (default 10)
|
||||
|
||||
|
||||
Returns:
|
||||
Cropped PNG image bytes, or None if no QR code found
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If pyzbar library not available
|
||||
"""
|
||||
@@ -416,27 +414,27 @@ def detect_and_crop_qr(
|
||||
"pyzbar library not installed. Run: pip install pyzbar\n"
|
||||
"Also requires system library: sudo apt-get install libzbar0"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
img: Image.Image = Image.open(io.BytesIO(image_data))
|
||||
original_mode = img.mode
|
||||
|
||||
|
||||
# Convert for pyzbar detection
|
||||
if img.mode not in ('RGB', 'L'):
|
||||
detect_img = img.convert('RGB')
|
||||
else:
|
||||
detect_img = img
|
||||
|
||||
|
||||
# Decode QR codes to get corner positions
|
||||
decoded = pyzbar_decode(detect_img, symbols=[ZBarSymbol.QRCODE])
|
||||
|
||||
|
||||
if not decoded:
|
||||
return None
|
||||
|
||||
|
||||
# Get the polygon corners of the first QR code
|
||||
# pyzbar returns a Polygon with Point objects (x, y attributes)
|
||||
polygon = decoded[0].polygon
|
||||
|
||||
|
||||
if len(polygon) < 4:
|
||||
# Fallback to rect if polygon not available
|
||||
rect = decoded[0].rect
|
||||
@@ -448,25 +446,25 @@ def detect_and_crop_qr(
|
||||
ys = [p.y for p in polygon]
|
||||
min_x, max_x = min(xs), max(xs)
|
||||
min_y, max_y = min(ys), max(ys)
|
||||
|
||||
|
||||
# Calculate QR dimensions and padding
|
||||
qr_width = max_x - min_x
|
||||
qr_height = max_y - min_y
|
||||
|
||||
|
||||
# Use larger dimension for padding calculation (handles rotation)
|
||||
qr_size = max(qr_width, qr_height)
|
||||
padding = max(int(qr_size * padding_percent), min_padding_px)
|
||||
|
||||
|
||||
# Calculate crop box with padding, clamped to image bounds
|
||||
img_width, img_height = img.size
|
||||
crop_left = max(0, min_x - padding)
|
||||
crop_top = max(0, min_y - padding)
|
||||
crop_right = min(img_width, max_x + padding)
|
||||
crop_bottom = min(img_height, max_y + padding)
|
||||
|
||||
|
||||
# Crop the original image (preserves original mode/quality)
|
||||
cropped = img.crop((crop_left, crop_top, crop_right, crop_bottom))
|
||||
|
||||
|
||||
# Convert to PNG bytes
|
||||
buf = io.BytesIO()
|
||||
# Preserve transparency if present
|
||||
@@ -476,7 +474,7 @@ def detect_and_crop_qr(
|
||||
cropped.save(buf, format='PNG')
|
||||
buf.seek(0)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Log for debugging but return None for clean API
|
||||
import sys
|
||||
@@ -488,15 +486,15 @@ def detect_and_crop_qr_file(
|
||||
filepath: str,
|
||||
padding_percent: float = QR_CROP_PADDING_PERCENT,
|
||||
min_padding_px: int = QR_CROP_MIN_PADDING_PX
|
||||
) -> Optional[bytes]:
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Detect QR code in image file and crop to it.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Path to image file
|
||||
padding_percent: Padding as fraction of QR size (default 10%)
|
||||
min_padding_px: Minimum padding in pixels (default 10)
|
||||
|
||||
|
||||
Returns:
|
||||
Cropped PNG image bytes, or None if no QR code found
|
||||
"""
|
||||
|
||||
@@ -20,22 +20,24 @@ Changes in v3.2.0:
|
||||
|
||||
import io
|
||||
import struct
|
||||
from typing import Optional, Tuple, List, Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from PIL import Image
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
|
||||
from PIL import Image
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .dct_steganography import DCTEmbedStats
|
||||
|
||||
from .models import EmbedStats, FilePayload
|
||||
from .exceptions import CapacityError, ExtractionError, EmbeddingError
|
||||
from .debug import debug
|
||||
from .constants import (
|
||||
EMBED_MODE_LSB,
|
||||
EMBED_MODE_DCT,
|
||||
EMBED_MODE_AUTO,
|
||||
EMBED_MODE_DCT,
|
||||
EMBED_MODE_LSB,
|
||||
VALID_EMBED_MODES,
|
||||
)
|
||||
|
||||
from .debug import debug
|
||||
from .exceptions import CapacityError, EmbeddingError
|
||||
from .models import EmbedStats, FilePayload
|
||||
|
||||
# Lossless formats that preserve LSB data
|
||||
LOSSLESS_FORMATS = {'PNG', 'BMP', 'TIFF'}
|
||||
@@ -103,10 +105,10 @@ def _get_dct_module():
|
||||
def has_dct_support() -> bool:
|
||||
"""
|
||||
Check if DCT steganography mode is available.
|
||||
|
||||
|
||||
Returns:
|
||||
True if scipy is installed and DCT functions work
|
||||
|
||||
|
||||
Example:
|
||||
>>> if has_dct_support():
|
||||
... result = encode(..., embed_mode='dct')
|
||||
@@ -122,26 +124,26 @@ def has_dct_support() -> bool:
|
||||
# FORMAT UTILITIES
|
||||
# =============================================================================
|
||||
|
||||
def get_output_format(input_format: Optional[str]) -> Tuple[str, str]:
|
||||
def get_output_format(input_format: str | None) -> tuple[str, str]:
|
||||
"""
|
||||
Determine the output format based on input format.
|
||||
|
||||
|
||||
Args:
|
||||
input_format: PIL format string of input image (e.g., 'JPEG', 'PNG')
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (PIL format string, file extension) for output
|
||||
Falls back to PNG for lossy or unknown formats.
|
||||
"""
|
||||
debug.validate(input_format is None or isinstance(input_format, str),
|
||||
"Input format must be string or None")
|
||||
|
||||
|
||||
if input_format and input_format.upper() in LOSSLESS_FORMATS:
|
||||
fmt = input_format.upper()
|
||||
ext = FORMAT_TO_EXT.get(fmt, 'png')
|
||||
debug.print(f"Using lossless format: {fmt} -> .{ext}")
|
||||
return fmt, ext
|
||||
|
||||
|
||||
debug.print(f"Input format {input_format} is lossy or unknown, defaulting to PNG")
|
||||
return 'PNG', 'png'
|
||||
|
||||
@@ -151,20 +153,20 @@ def get_output_format(input_format: Optional[str]) -> Tuple[str, str]:
|
||||
# =============================================================================
|
||||
|
||||
def will_fit(
|
||||
payload: Union[str, bytes, FilePayload, int],
|
||||
payload: str | bytes | FilePayload | int,
|
||||
carrier_image: bytes,
|
||||
bits_per_channel: int = 1,
|
||||
include_compression_estimate: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Check if a payload will fit in a carrier image (LSB mode).
|
||||
|
||||
|
||||
Args:
|
||||
payload: Message string, raw bytes, FilePayload, or size in bytes
|
||||
carrier_image: Carrier image bytes
|
||||
bits_per_channel: Bits to use per color channel (1-2)
|
||||
include_compression_estimate: Estimate compressed size
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with fits, capacity, usage info
|
||||
"""
|
||||
@@ -183,15 +185,15 @@ def will_fit(
|
||||
else:
|
||||
payload_data = payload
|
||||
payload_size = len(payload)
|
||||
|
||||
|
||||
capacity = calculate_capacity(carrier_image, bits_per_channel)
|
||||
|
||||
|
||||
# Estimate encrypted size with padding
|
||||
# Padding adds 64-319 bytes, rounded up to 256-byte boundary
|
||||
# Average case: ~190 bytes padding
|
||||
estimated_padding = 190
|
||||
estimated_encrypted_size = payload_size + estimated_padding + ENCRYPTION_OVERHEAD
|
||||
|
||||
|
||||
compressed_estimate = None
|
||||
if include_compression_estimate and payload_data is not None and len(payload_data) >= 64:
|
||||
try:
|
||||
@@ -203,11 +205,11 @@ def will_fit(
|
||||
estimated_encrypted_size = compressed_size + estimated_padding + ENCRYPTION_OVERHEAD
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
headroom = capacity - estimated_encrypted_size
|
||||
fits = headroom >= 0
|
||||
usage_percent = (estimated_encrypted_size / capacity * 100) if capacity > 0 else 100.0
|
||||
|
||||
|
||||
return {
|
||||
'fits': fits,
|
||||
'payload_size': payload_size,
|
||||
@@ -223,23 +225,23 @@ def will_fit(
|
||||
def calculate_capacity(image_data: bytes, bits_per_channel: int = 1) -> int:
|
||||
"""
|
||||
Calculate the maximum message capacity of an image (LSB mode).
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Image bytes
|
||||
bits_per_channel: Bits to use per color channel
|
||||
|
||||
|
||||
Returns:
|
||||
Maximum bytes that can be embedded (minus overhead)
|
||||
"""
|
||||
debug.validate(bits_per_channel in (1, 2),
|
||||
f"bits_per_channel must be 1 or 2, got {bits_per_channel}")
|
||||
|
||||
|
||||
img_file = Image.open(io.BytesIO(image_data))
|
||||
try:
|
||||
num_pixels = img_file.size[0] * img_file.size[1]
|
||||
bits_per_pixel = 3 * bits_per_channel
|
||||
max_bytes = (num_pixels * bits_per_pixel) // 8
|
||||
|
||||
|
||||
capacity = max(0, max_bytes - ENCRYPTION_OVERHEAD)
|
||||
debug.print(f"LSB capacity: {capacity} bytes at {bits_per_channel} bit(s)/channel")
|
||||
return capacity
|
||||
@@ -248,28 +250,28 @@ def calculate_capacity(image_data: bytes, bits_per_channel: int = 1) -> int:
|
||||
|
||||
|
||||
def calculate_capacity_by_mode(
|
||||
image_data: bytes,
|
||||
image_data: bytes,
|
||||
embed_mode: str = EMBED_MODE_LSB,
|
||||
bits_per_channel: int = 1,
|
||||
) -> dict:
|
||||
"""
|
||||
Calculate capacity for specified embedding mode.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Carrier image bytes
|
||||
embed_mode: 'lsb' or 'dct'
|
||||
bits_per_channel: Bits per channel for LSB mode
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with capacity information
|
||||
"""
|
||||
if embed_mode == EMBED_MODE_DCT:
|
||||
if not has_dct_support():
|
||||
raise ImportError("scipy required for DCT mode. Install: pip install scipy")
|
||||
|
||||
|
||||
dct_mod = _get_dct_module()
|
||||
dct_info = dct_mod.calculate_dct_capacity(image_data)
|
||||
|
||||
|
||||
return {
|
||||
'mode': EMBED_MODE_DCT,
|
||||
'capacity_bytes': dct_info.usable_capacity_bytes,
|
||||
@@ -285,7 +287,7 @@ def calculate_capacity_by_mode(
|
||||
width, height = img.size
|
||||
finally:
|
||||
img.close()
|
||||
|
||||
|
||||
return {
|
||||
'mode': EMBED_MODE_LSB,
|
||||
'capacity_bytes': capacity,
|
||||
@@ -297,27 +299,27 @@ def calculate_capacity_by_mode(
|
||||
|
||||
|
||||
def will_fit_by_mode(
|
||||
payload: Union[str, bytes, FilePayload, int],
|
||||
payload: str | bytes | FilePayload | int,
|
||||
carrier_image: bytes,
|
||||
embed_mode: str = EMBED_MODE_LSB,
|
||||
bits_per_channel: int = 1,
|
||||
) -> dict:
|
||||
"""
|
||||
Check if payload fits in specified mode.
|
||||
|
||||
|
||||
Args:
|
||||
payload: Message, bytes, FilePayload, or size in bytes
|
||||
carrier_image: Carrier image bytes
|
||||
embed_mode: 'lsb' or 'dct'
|
||||
bits_per_channel: For LSB mode
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with fits, capacity, usage info
|
||||
"""
|
||||
if embed_mode == EMBED_MODE_DCT:
|
||||
if not has_dct_support():
|
||||
return {'fits': False, 'error': 'scipy not available', 'mode': EMBED_MODE_DCT}
|
||||
|
||||
|
||||
if isinstance(payload, int):
|
||||
payload_size = payload
|
||||
elif isinstance(payload, str):
|
||||
@@ -326,16 +328,16 @@ def will_fit_by_mode(
|
||||
payload_size = len(payload.data)
|
||||
else:
|
||||
payload_size = len(payload)
|
||||
|
||||
|
||||
estimated_size = payload_size + ENCRYPTION_OVERHEAD + 190 # padding estimate
|
||||
|
||||
|
||||
dct_mod = _get_dct_module()
|
||||
fits = dct_mod.will_fit_dct(estimated_size, carrier_image)
|
||||
capacity_info = dct_mod.calculate_dct_capacity(carrier_image)
|
||||
capacity = capacity_info.usable_capacity_bytes
|
||||
|
||||
|
||||
usage_percent = (estimated_size / capacity * 100) if capacity > 0 else 100.0
|
||||
|
||||
|
||||
return {
|
||||
'fits': fits,
|
||||
'payload_size': payload_size,
|
||||
@@ -351,7 +353,7 @@ def will_fit_by_mode(
|
||||
def get_available_modes() -> dict:
|
||||
"""
|
||||
Get available embedding modes and their status.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict mapping mode name to availability info
|
||||
"""
|
||||
@@ -375,10 +377,10 @@ def get_available_modes() -> dict:
|
||||
def compare_modes(image_data: bytes) -> dict:
|
||||
"""
|
||||
Compare embedding modes for a carrier image.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Carrier image bytes
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with comparison of LSB vs DCT modes
|
||||
"""
|
||||
@@ -387,9 +389,9 @@ def compare_modes(image_data: bytes) -> dict:
|
||||
width, height = img.size
|
||||
finally:
|
||||
img.close()
|
||||
|
||||
|
||||
lsb_bytes = calculate_capacity(image_data, 1)
|
||||
|
||||
|
||||
if has_dct_support():
|
||||
dct_mod = _get_dct_module()
|
||||
dct_info = dct_mod.calculate_dct_capacity(image_data)
|
||||
@@ -399,7 +401,7 @@ def compare_modes(image_data: bytes) -> dict:
|
||||
safe_blocks = (height // 8) * (width // 8)
|
||||
dct_bytes = (safe_blocks * 16) // 8 # Estimated
|
||||
dct_available = False
|
||||
|
||||
|
||||
return {
|
||||
'width': width,
|
||||
'height': height,
|
||||
@@ -424,62 +426,62 @@ def compare_modes(image_data: bytes) -> dict:
|
||||
# =============================================================================
|
||||
|
||||
@debug.time
|
||||
def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> List[int]:
|
||||
def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list[int]:
|
||||
"""
|
||||
Generate pseudo-random pixel indices for embedding.
|
||||
|
||||
|
||||
Uses ChaCha20 as a CSPRNG seeded by the key to deterministically
|
||||
select which pixels will hold hidden data.
|
||||
"""
|
||||
debug.validate(len(key) == 32, f"Pixel key must be 32 bytes, got {len(key)}")
|
||||
debug.validate(num_pixels > 0, f"Number of pixels must be positive, got {num_pixels}")
|
||||
debug.validate(num_needed > 0, f"Number needed must be positive, got {num_needed}")
|
||||
debug.validate(num_needed <= num_pixels,
|
||||
debug.validate(num_needed <= num_pixels,
|
||||
f"Cannot select {num_needed} pixels from {num_pixels} available")
|
||||
|
||||
|
||||
debug.print(f"Generating {num_needed} pixel indices from {num_pixels} total pixels")
|
||||
|
||||
|
||||
if num_needed >= num_pixels // 2:
|
||||
debug.print(f"Using full shuffle (needed {num_needed}/{num_pixels} pixels)")
|
||||
nonce = b'\x00' * 16
|
||||
cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
|
||||
indices = list(range(num_pixels))
|
||||
random_bytes = encryptor.update(b'\x00' * (num_pixels * 4))
|
||||
|
||||
|
||||
for i in range(num_pixels - 1, 0, -1):
|
||||
j_bytes = random_bytes[(num_pixels - 1 - i) * 4:(num_pixels - i) * 4]
|
||||
j = int.from_bytes(j_bytes, 'big') % (i + 1)
|
||||
indices[i], indices[j] = indices[j], indices[i]
|
||||
|
||||
|
||||
selected = indices[:num_needed]
|
||||
debug.print(f"Generated {len(selected)} indices via shuffle")
|
||||
return selected
|
||||
|
||||
|
||||
debug.print(f"Using optimized selection (needed {num_needed}/{num_pixels} pixels)")
|
||||
selected = []
|
||||
used = set()
|
||||
|
||||
|
||||
nonce = b'\x00' * 16
|
||||
cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
|
||||
bytes_needed = (num_needed * 2) * 4
|
||||
random_bytes = encryptor.update(b'\x00' * bytes_needed)
|
||||
|
||||
|
||||
byte_offset = 0
|
||||
collisions = 0
|
||||
while len(selected) < num_needed and byte_offset < len(random_bytes) - 4:
|
||||
idx = int.from_bytes(random_bytes[byte_offset:byte_offset + 4], 'big') % num_pixels
|
||||
byte_offset += 4
|
||||
|
||||
|
||||
if idx not in used:
|
||||
used.add(idx)
|
||||
selected.append(idx)
|
||||
else:
|
||||
collisions += 1
|
||||
|
||||
|
||||
if len(selected) < num_needed:
|
||||
debug.print(f"Need {num_needed - len(selected)} more indices, generating...")
|
||||
extra_needed = num_needed - len(selected)
|
||||
@@ -491,7 +493,7 @@ def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> List
|
||||
selected.append(idx)
|
||||
if len(selected) == num_needed:
|
||||
break
|
||||
|
||||
|
||||
debug.print(f"Generated {len(selected)} indices with {collisions} collisions")
|
||||
debug.validate(len(selected) == num_needed,
|
||||
f"Failed to generate enough indices: {len(selected)}/{num_needed}")
|
||||
@@ -508,14 +510,14 @@ def embed_in_image(
|
||||
image_data: bytes,
|
||||
pixel_key: bytes,
|
||||
bits_per_channel: int = 1,
|
||||
output_format: Optional[str] = None,
|
||||
output_format: str | None = None,
|
||||
embed_mode: str = EMBED_MODE_LSB,
|
||||
dct_output_format: str = DCT_OUTPUT_PNG,
|
||||
dct_color_mode: str = 'grayscale',
|
||||
) -> Tuple[bytes, Union[EmbedStats, 'DCTEmbedStats'], str]:
|
||||
) -> tuple[bytes, Union[EmbedStats, 'DCTEmbedStats'], str]:
|
||||
"""
|
||||
Embed data into an image using specified mode.
|
||||
|
||||
|
||||
Args:
|
||||
data: Data to embed (encrypted payload)
|
||||
image_data: Carrier image bytes
|
||||
@@ -525,19 +527,19 @@ def embed_in_image(
|
||||
embed_mode: 'lsb' (default) or 'dct'
|
||||
dct_output_format: For DCT mode - 'png' (lossless) or 'jpeg' (smaller)
|
||||
dct_color_mode: For DCT mode - 'grayscale' (default) or 'color' (preserves colors)
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (stego image bytes, stats, file extension)
|
||||
|
||||
|
||||
Raises:
|
||||
CapacityError: If data won't fit
|
||||
EmbeddingError: If embedding fails
|
||||
ImportError: If DCT mode requested but scipy unavailable
|
||||
"""
|
||||
debug.print(f"embed_in_image: mode={embed_mode}, data={len(data)} bytes")
|
||||
debug.validate(embed_mode in VALID_EMBED_MODES,
|
||||
debug.validate(embed_mode in VALID_EMBED_MODES,
|
||||
f"Invalid embed_mode: {embed_mode}. Use 'lsb' or 'dct'")
|
||||
|
||||
|
||||
# DCT MODE
|
||||
if embed_mode == EMBED_MODE_DCT:
|
||||
if not has_dct_support():
|
||||
@@ -545,38 +547,38 @@ def embed_in_image(
|
||||
"scipy is required for DCT embedding mode. "
|
||||
"Install with: pip install scipy"
|
||||
)
|
||||
|
||||
|
||||
# Validate DCT output format
|
||||
if dct_output_format not in (DCT_OUTPUT_PNG, DCT_OUTPUT_JPEG):
|
||||
debug.print(f"Invalid dct_output_format '{dct_output_format}', defaulting to PNG")
|
||||
dct_output_format = DCT_OUTPUT_PNG
|
||||
|
||||
|
||||
# Validate DCT color mode (v3.0.1)
|
||||
if dct_color_mode not in ('grayscale', 'color'):
|
||||
debug.print(f"Invalid dct_color_mode '{dct_color_mode}', defaulting to grayscale")
|
||||
dct_color_mode = 'grayscale'
|
||||
|
||||
|
||||
dct_mod = _get_dct_module()
|
||||
|
||||
|
||||
# Pass output_format and color_mode to DCT module (v3.0.1)
|
||||
stego_bytes, dct_stats = dct_mod.embed_in_dct(
|
||||
data,
|
||||
image_data,
|
||||
data,
|
||||
image_data,
|
||||
pixel_key,
|
||||
output_format=dct_output_format,
|
||||
color_mode=dct_color_mode,
|
||||
)
|
||||
|
||||
|
||||
# Determine extension based on output format
|
||||
if dct_output_format == DCT_OUTPUT_JPEG:
|
||||
ext = 'jpg'
|
||||
else:
|
||||
ext = 'png'
|
||||
|
||||
|
||||
debug.print(f"DCT embedding complete: {dct_output_format.upper()} output, "
|
||||
f"color_mode={dct_color_mode}, ext={ext}")
|
||||
return stego_bytes, dct_stats, ext
|
||||
|
||||
|
||||
# LSB MODE
|
||||
return _embed_lsb(data, image_data, pixel_key, bits_per_channel, output_format)
|
||||
|
||||
@@ -586,75 +588,75 @@ def _embed_lsb(
|
||||
image_data: bytes,
|
||||
pixel_key: bytes,
|
||||
bits_per_channel: int = 1,
|
||||
output_format: Optional[str] = None,
|
||||
) -> Tuple[bytes, EmbedStats, str]:
|
||||
output_format: str | None = None,
|
||||
) -> tuple[bytes, EmbedStats, str]:
|
||||
"""
|
||||
Embed data using LSB steganography (internal implementation).
|
||||
"""
|
||||
debug.print(f"LSB embedding {len(data)} bytes into image")
|
||||
debug.data(pixel_key, "Pixel key for embedding")
|
||||
debug.validate(bits_per_channel in (1, 2),
|
||||
debug.validate(bits_per_channel in (1, 2),
|
||||
f"bits_per_channel must be 1 or 2, got {bits_per_channel}")
|
||||
debug.validate(len(pixel_key) == 32,
|
||||
f"Pixel key must be 32 bytes, got {len(pixel_key)}")
|
||||
|
||||
|
||||
img_file = None
|
||||
img = None
|
||||
stego_img = None
|
||||
|
||||
|
||||
try:
|
||||
img_file = Image.open(io.BytesIO(image_data))
|
||||
input_format = img_file.format
|
||||
|
||||
|
||||
debug.print(f"Carrier image: {img_file.size[0]}x{img_file.size[1]}, format: {input_format}")
|
||||
|
||||
|
||||
img = img_file.convert('RGB') if img_file.mode != 'RGB' else img_file.copy()
|
||||
if img_file.mode != 'RGB':
|
||||
debug.print(f"Converting image from {img_file.mode} to RGB")
|
||||
|
||||
|
||||
pixels = list(img.getdata())
|
||||
num_pixels = len(pixels)
|
||||
|
||||
|
||||
bits_per_pixel = 3 * bits_per_channel
|
||||
max_bytes = (num_pixels * bits_per_pixel) // 8
|
||||
|
||||
|
||||
debug.print(f"Image capacity: {max_bytes} bytes at {bits_per_channel} bit(s)/channel")
|
||||
|
||||
|
||||
data_with_len = struct.pack('>I', len(data)) + data
|
||||
|
||||
|
||||
if len(data_with_len) > max_bytes:
|
||||
debug.print(f"Capacity error: need {len(data_with_len)}, have {max_bytes}")
|
||||
raise CapacityError(len(data_with_len), max_bytes)
|
||||
|
||||
|
||||
debug.print(f"Total data to embed: {len(data_with_len)} bytes "
|
||||
f"({len(data_with_len)/max_bytes*100:.1f}% of capacity)")
|
||||
|
||||
|
||||
binary_data = ''.join(format(b, '08b') for b in data_with_len)
|
||||
pixels_needed = (len(binary_data) + bits_per_pixel - 1) // bits_per_pixel
|
||||
|
||||
|
||||
debug.print(f"Need {pixels_needed} pixels to embed {len(binary_data)} bits")
|
||||
|
||||
|
||||
selected_indices = generate_pixel_indices(pixel_key, num_pixels, pixels_needed)
|
||||
|
||||
|
||||
new_pixels = list(pixels)
|
||||
clear_mask = 0xFF ^ ((1 << bits_per_channel) - 1)
|
||||
|
||||
|
||||
bit_idx = 0
|
||||
modified_pixels = 0
|
||||
|
||||
|
||||
for pixel_idx in selected_indices:
|
||||
if bit_idx >= len(binary_data):
|
||||
break
|
||||
|
||||
|
||||
r, g, b = new_pixels[pixel_idx]
|
||||
modified = False
|
||||
|
||||
|
||||
for channel_idx, channel_val in enumerate([r, g, b]):
|
||||
if bit_idx >= len(binary_data):
|
||||
break
|
||||
bits = binary_data[bit_idx:bit_idx + bits_per_channel].ljust(bits_per_channel, '0')
|
||||
new_val = (channel_val & clear_mask) | int(bits, 2)
|
||||
|
||||
|
||||
if channel_val != new_val:
|
||||
modified = True
|
||||
if channel_idx == 0:
|
||||
@@ -663,18 +665,18 @@ def _embed_lsb(
|
||||
g = new_val
|
||||
else:
|
||||
b = new_val
|
||||
|
||||
|
||||
bit_idx += bits_per_channel
|
||||
|
||||
|
||||
if modified:
|
||||
new_pixels[pixel_idx] = (r, g, b)
|
||||
modified_pixels += 1
|
||||
|
||||
|
||||
debug.print(f"Modified {modified_pixels} pixels (out of {len(selected_indices)} selected)")
|
||||
|
||||
|
||||
stego_img = Image.new('RGB', img.size)
|
||||
stego_img.putdata(new_pixels)
|
||||
|
||||
|
||||
if output_format:
|
||||
out_fmt = output_format.upper()
|
||||
out_ext = FORMAT_TO_EXT.get(out_fmt, 'png')
|
||||
@@ -682,21 +684,21 @@ def _embed_lsb(
|
||||
else:
|
||||
out_fmt, out_ext = get_output_format(input_format)
|
||||
debug.print(f"Auto-selected output format: {out_fmt}")
|
||||
|
||||
|
||||
output = io.BytesIO()
|
||||
stego_img.save(output, out_fmt)
|
||||
output.seek(0)
|
||||
|
||||
|
||||
stats = EmbedStats(
|
||||
pixels_modified=modified_pixels,
|
||||
total_pixels=num_pixels,
|
||||
capacity_used=len(data_with_len) / max_bytes,
|
||||
bytes_embedded=len(data_with_len)
|
||||
)
|
||||
|
||||
|
||||
debug.print(f"LSB embedding complete: {out_fmt} image, {len(output.getvalue())} bytes")
|
||||
return output.getvalue(), stats, out_ext
|
||||
|
||||
|
||||
except CapacityError:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -722,50 +724,50 @@ def extract_from_image(
|
||||
pixel_key: bytes,
|
||||
bits_per_channel: int = 1,
|
||||
embed_mode: str = EMBED_MODE_AUTO,
|
||||
) -> Optional[bytes]:
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Extract hidden data from a stego image.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Stego image bytes
|
||||
pixel_key: Key for pixel/coefficient selection (must match encoding)
|
||||
bits_per_channel: Bits per channel (LSB mode only)
|
||||
embed_mode: 'auto' (try both), 'lsb', or 'dct'
|
||||
|
||||
|
||||
Returns:
|
||||
Extracted data bytes, or None if extraction fails
|
||||
"""
|
||||
debug.print(f"extract_from_image: mode={embed_mode}")
|
||||
|
||||
|
||||
# AUTO MODE: Try LSB first, then DCT
|
||||
if embed_mode == EMBED_MODE_AUTO:
|
||||
result = _extract_lsb(image_data, pixel_key, bits_per_channel)
|
||||
if result is not None:
|
||||
debug.print("Auto-detect: LSB extraction succeeded")
|
||||
return result
|
||||
|
||||
|
||||
if has_dct_support():
|
||||
debug.print("Auto-detect: LSB failed, trying DCT")
|
||||
result = _extract_dct(image_data, pixel_key)
|
||||
if result is not None:
|
||||
debug.print("Auto-detect: DCT extraction succeeded")
|
||||
return result
|
||||
|
||||
|
||||
debug.print("Auto-detect: All modes failed")
|
||||
return None
|
||||
|
||||
|
||||
# EXPLICIT DCT MODE
|
||||
elif embed_mode == EMBED_MODE_DCT:
|
||||
if not has_dct_support():
|
||||
raise ImportError("scipy required for DCT mode")
|
||||
return _extract_dct(image_data, pixel_key)
|
||||
|
||||
|
||||
# EXPLICIT LSB MODE
|
||||
else:
|
||||
return _extract_lsb(image_data, pixel_key, bits_per_channel)
|
||||
|
||||
|
||||
def _extract_dct(image_data: bytes, pixel_key: bytes) -> Optional[bytes]:
|
||||
def _extract_dct(image_data: bytes, pixel_key: bytes) -> bytes | None:
|
||||
"""Extract using DCT mode."""
|
||||
try:
|
||||
dct_mod = _get_dct_module()
|
||||
@@ -779,7 +781,7 @@ def _extract_lsb(
|
||||
image_data: bytes,
|
||||
pixel_key: bytes,
|
||||
bits_per_channel: int = 1
|
||||
) -> Optional[bytes]:
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Extract using LSB mode (internal implementation).
|
||||
"""
|
||||
@@ -787,82 +789,82 @@ def _extract_lsb(
|
||||
debug.data(pixel_key, "Pixel key for extraction")
|
||||
debug.validate(bits_per_channel in (1, 2),
|
||||
f"bits_per_channel must be 1 or 2, got {bits_per_channel}")
|
||||
|
||||
|
||||
img_file = None
|
||||
img = None
|
||||
|
||||
|
||||
try:
|
||||
img_file = Image.open(io.BytesIO(image_data))
|
||||
debug.print(f"Image: {img_file.size[0]}x{img_file.size[1]}, format: {img_file.format}")
|
||||
|
||||
|
||||
img = img_file.convert('RGB') if img_file.mode != 'RGB' else img_file.copy()
|
||||
if img_file.mode != 'RGB':
|
||||
debug.print(f"Converting image from {img_file.mode} to RGB")
|
||||
|
||||
|
||||
pixels = list(img.getdata())
|
||||
num_pixels = len(pixels)
|
||||
bits_per_pixel = 3 * bits_per_channel
|
||||
|
||||
|
||||
debug.print(f"Image has {num_pixels} pixels, {bits_per_pixel} bits/pixel")
|
||||
|
||||
|
||||
initial_pixels = (32 + bits_per_pixel - 1) // bits_per_pixel + 10
|
||||
debug.print(f"Extracting initial {initial_pixels} pixels to find length")
|
||||
|
||||
|
||||
initial_indices = generate_pixel_indices(pixel_key, num_pixels, initial_pixels)
|
||||
|
||||
|
||||
binary_data = ''
|
||||
for pixel_idx in initial_indices:
|
||||
r, g, b = pixels[pixel_idx]
|
||||
for channel in [r, g, b]:
|
||||
for bit_pos in range(bits_per_channel - 1, -1, -1):
|
||||
binary_data += str((channel >> bit_pos) & 1)
|
||||
|
||||
|
||||
try:
|
||||
length_bits = binary_data[:32]
|
||||
if len(length_bits) < 32:
|
||||
debug.print(f"Not enough bits for length: {len(length_bits)}/32")
|
||||
return None
|
||||
|
||||
|
||||
data_length = struct.unpack('>I', int(length_bits, 2).to_bytes(4, 'big'))[0]
|
||||
debug.print(f"Extracted length: {data_length} bytes")
|
||||
except Exception as e:
|
||||
debug.print(f"Failed to parse length: {e}")
|
||||
return None
|
||||
|
||||
|
||||
max_possible = (num_pixels * bits_per_pixel) // 8 - 4
|
||||
if data_length > max_possible or data_length < 10:
|
||||
debug.print(f"Invalid data length: {data_length} (max possible: {max_possible})")
|
||||
return None
|
||||
|
||||
|
||||
total_bits = (4 + data_length) * 8
|
||||
pixels_needed = (total_bits + bits_per_pixel - 1) // bits_per_pixel
|
||||
|
||||
|
||||
debug.print(f"Need {pixels_needed} pixels to extract {data_length} bytes")
|
||||
|
||||
|
||||
selected_indices = generate_pixel_indices(pixel_key, num_pixels, pixels_needed)
|
||||
|
||||
|
||||
binary_data = ''
|
||||
for pixel_idx in selected_indices:
|
||||
r, g, b = pixels[pixel_idx]
|
||||
for channel in [r, g, b]:
|
||||
for bit_pos in range(bits_per_channel - 1, -1, -1):
|
||||
binary_data += str((channel >> bit_pos) & 1)
|
||||
|
||||
|
||||
data_bits = binary_data[32:32 + (data_length * 8)]
|
||||
|
||||
|
||||
if len(data_bits) < data_length * 8:
|
||||
debug.print(f"Insufficient bits: {len(data_bits)} < {data_length * 8}")
|
||||
return None
|
||||
|
||||
|
||||
data_bytes = bytearray()
|
||||
for i in range(0, len(data_bits), 8):
|
||||
byte_bits = data_bits[i:i + 8]
|
||||
if len(byte_bits) == 8:
|
||||
data_bytes.append(int(byte_bits, 2))
|
||||
|
||||
|
||||
debug.print(f"LSB successfully extracted {len(data_bytes)} bytes")
|
||||
return bytes(data_bytes)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
debug.exception(e, "extract_lsb")
|
||||
return None
|
||||
@@ -878,7 +880,7 @@ def _extract_lsb(
|
||||
# UTILITY FUNCTIONS
|
||||
# =============================================================================
|
||||
|
||||
def get_image_dimensions(image_data: bytes) -> Tuple[int, int]:
|
||||
def get_image_dimensions(image_data: bytes) -> tuple[int, int]:
|
||||
"""Get image dimensions without loading full image."""
|
||||
debug.validate(len(image_data) > 0, "Image data cannot be empty")
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
@@ -890,7 +892,7 @@ def get_image_dimensions(image_data: bytes) -> Tuple[int, int]:
|
||||
img.close()
|
||||
|
||||
|
||||
def get_image_format(image_data: bytes) -> Optional[str]:
|
||||
def get_image_format(image_data: bytes) -> str | None:
|
||||
"""Get image format (PIL format string like 'PNG', 'JPEG')."""
|
||||
try:
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
@@ -9,9 +9,8 @@ import os
|
||||
import random
|
||||
import secrets
|
||||
import shutil
|
||||
from datetime import date, datetime
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -22,98 +21,98 @@ from .debug import debug
|
||||
def strip_image_metadata(image_data: bytes, output_format: str = 'PNG') -> bytes:
|
||||
"""
|
||||
Remove all metadata (EXIF, ICC profiles, etc.) from an image.
|
||||
|
||||
|
||||
Creates a fresh image with only pixel data - no EXIF, GPS coordinates,
|
||||
camera info, timestamps, or other potentially sensitive metadata.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
output_format: Output format ('PNG', 'BMP', 'TIFF')
|
||||
|
||||
|
||||
Returns:
|
||||
Clean image bytes with no metadata
|
||||
|
||||
|
||||
Example:
|
||||
>>> clean = strip_image_metadata(photo_bytes)
|
||||
>>> # EXIF data is now removed
|
||||
"""
|
||||
debug.print(f"Stripping metadata, output format: {output_format}")
|
||||
|
||||
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
|
||||
# Convert to RGB if needed (handles RGBA, P, L, etc.)
|
||||
if img.mode not in ('RGB', 'RGBA'):
|
||||
img = img.convert('RGB')
|
||||
|
||||
|
||||
# Create fresh image - this discards all metadata
|
||||
clean = Image.new(img.mode, img.size)
|
||||
clean.putdata(list(img.getdata()))
|
||||
|
||||
|
||||
output = io.BytesIO()
|
||||
clean.save(output, output_format.upper())
|
||||
output.seek(0)
|
||||
|
||||
|
||||
debug.print(f"Metadata stripped: {len(image_data)} -> {len(output.getvalue())} bytes")
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
def generate_filename(
|
||||
date_str: Optional[str] = None,
|
||||
date_str: str | None = None,
|
||||
prefix: str = "",
|
||||
extension: str = "png"
|
||||
) -> str:
|
||||
"""
|
||||
Generate a filename for stego images.
|
||||
|
||||
|
||||
Format: {prefix}{random}_{YYYYMMDD}.{extension}
|
||||
|
||||
|
||||
Args:
|
||||
date_str: Date string (YYYY-MM-DD), defaults to today
|
||||
prefix: Optional prefix
|
||||
extension: File extension without dot (default: 'png')
|
||||
|
||||
|
||||
Returns:
|
||||
Filename string
|
||||
|
||||
|
||||
Example:
|
||||
>>> generate_filename("2023-12-25", "secret_", "png")
|
||||
"secret_a1b2c3d4_20231225.png"
|
||||
"""
|
||||
debug.validate(bool(extension) and '.' not in extension,
|
||||
f"Extension must not contain dot, got '{extension}'")
|
||||
|
||||
|
||||
if date_str is None:
|
||||
date_str = date.today().isoformat()
|
||||
|
||||
|
||||
date_compact = date_str.replace('-', '')
|
||||
random_hex = secrets.token_hex(4)
|
||||
|
||||
|
||||
# Ensure extension doesn't have a leading dot
|
||||
extension = extension.lstrip('.')
|
||||
|
||||
|
||||
filename = f"{prefix}{random_hex}_{date_compact}.{extension}"
|
||||
debug.print(f"Generated filename: {filename}")
|
||||
return filename
|
||||
|
||||
|
||||
def parse_date_from_filename(filename: str) -> Optional[str]:
|
||||
def parse_date_from_filename(filename: str) -> str | None:
|
||||
"""
|
||||
Extract date from a stego filename.
|
||||
|
||||
|
||||
Looks for patterns like _20251227 or _2025-12-27
|
||||
|
||||
|
||||
Args:
|
||||
filename: Filename to parse
|
||||
|
||||
|
||||
Returns:
|
||||
Date string (YYYY-MM-DD) or None
|
||||
|
||||
|
||||
Example:
|
||||
>>> parse_date_from_filename("secret_a1b2c3d4_20231225.png")
|
||||
"2023-12-25"
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
# Try YYYYMMDD format
|
||||
match = re.search(r'_(\d{4})(\d{2})(\d{2})(?:\.|$)', filename)
|
||||
if match:
|
||||
@@ -121,7 +120,7 @@ def parse_date_from_filename(filename: str) -> Optional[str]:
|
||||
date_str = f"{year}-{month}-{day}"
|
||||
debug.print(f"Parsed date (compact): {date_str}")
|
||||
return date_str
|
||||
|
||||
|
||||
# Try YYYY-MM-DD format
|
||||
match = re.search(r'_(\d{4})-(\d{2})-(\d{2})(?:\.|$)', filename)
|
||||
if match:
|
||||
@@ -129,7 +128,7 @@ def parse_date_from_filename(filename: str) -> Optional[str]:
|
||||
date_str = f"{year}-{month}-{day}"
|
||||
debug.print(f"Parsed date (dashed): {date_str}")
|
||||
return date_str
|
||||
|
||||
|
||||
debug.print(f"No date found in filename: {filename}")
|
||||
return None
|
||||
|
||||
@@ -137,20 +136,20 @@ def parse_date_from_filename(filename: str) -> Optional[str]:
|
||||
def get_day_from_date(date_str: str) -> str:
|
||||
"""
|
||||
Get day of week name from date string.
|
||||
|
||||
|
||||
Args:
|
||||
date_str: Date string (YYYY-MM-DD)
|
||||
|
||||
|
||||
Returns:
|
||||
Day name (e.g., "Monday")
|
||||
|
||||
|
||||
Example:
|
||||
>>> get_day_from_date("2023-12-25")
|
||||
"Monday"
|
||||
"""
|
||||
debug.validate(len(date_str) == 10 and date_str[4] == '-' and date_str[7] == '-',
|
||||
f"Invalid date format: {date_str}, expected YYYY-MM-DD")
|
||||
|
||||
|
||||
try:
|
||||
year, month, day = map(int, date_str.split('-'))
|
||||
d = date(year, month, day)
|
||||
@@ -165,10 +164,10 @@ def get_day_from_date(date_str: str) -> str:
|
||||
def get_today_date() -> str:
|
||||
"""
|
||||
Get today's date as YYYY-MM-DD.
|
||||
|
||||
|
||||
Returns:
|
||||
Today's date string
|
||||
|
||||
|
||||
Example:
|
||||
>>> get_today_date()
|
||||
"2023-12-25"
|
||||
@@ -181,10 +180,10 @@ def get_today_date() -> str:
|
||||
def get_today_day() -> str:
|
||||
"""
|
||||
Get today's day name.
|
||||
|
||||
|
||||
Returns:
|
||||
Today's day name
|
||||
|
||||
|
||||
Example:
|
||||
>>> get_today_day()
|
||||
"Monday"
|
||||
@@ -197,43 +196,43 @@ def get_today_day() -> str:
|
||||
class SecureDeleter:
|
||||
"""
|
||||
Securely delete files by overwriting with random data.
|
||||
|
||||
|
||||
Implements multi-pass overwriting before deletion.
|
||||
|
||||
|
||||
Example:
|
||||
>>> deleter = SecureDeleter("secret.txt", passes=3)
|
||||
>>> deleter.execute()
|
||||
"""
|
||||
|
||||
def __init__(self, path: Union[str, Path], passes: int = 7):
|
||||
|
||||
def __init__(self, path: str | Path, passes: int = 7):
|
||||
"""
|
||||
Initialize secure deleter.
|
||||
|
||||
|
||||
Args:
|
||||
path: Path to file or directory
|
||||
passes: Number of overwrite passes
|
||||
"""
|
||||
debug.validate(passes > 0, f"Passes must be positive, got {passes}")
|
||||
|
||||
|
||||
self.path = Path(path)
|
||||
self.passes = passes
|
||||
debug.print(f"SecureDeleter initialized for {self.path} with {passes} passes")
|
||||
|
||||
|
||||
def _overwrite_file(self, file_path: Path) -> None:
|
||||
"""Overwrite file with random data multiple times."""
|
||||
if not file_path.exists() or not file_path.is_file():
|
||||
debug.print(f"File does not exist or is not a file: {file_path}")
|
||||
return
|
||||
|
||||
|
||||
length = file_path.stat().st_size
|
||||
debug.print(f"Overwriting file {file_path} ({length} bytes)")
|
||||
|
||||
|
||||
if length == 0:
|
||||
debug.print("File is empty, nothing to overwrite")
|
||||
return
|
||||
|
||||
|
||||
patterns = [b'\x00', b'\xFF', bytes([random.randint(0, 255)])]
|
||||
|
||||
|
||||
for pass_num in range(self.passes):
|
||||
debug.print(f"Overwrite pass {pass_num + 1}/{self.passes}")
|
||||
with open(file_path, 'r+b') as f:
|
||||
@@ -245,13 +244,13 @@ class SecureDeleter:
|
||||
chunk = min(chunk_size, length - offset)
|
||||
f.write(pattern * (chunk // len(pattern)))
|
||||
f.write(pattern[:chunk % len(pattern)])
|
||||
|
||||
|
||||
# Final pass with random data
|
||||
f.seek(0)
|
||||
f.write(os.urandom(length))
|
||||
|
||||
|
||||
debug.print(f"Completed {self.passes} overwrite passes")
|
||||
|
||||
|
||||
def delete_file(self) -> None:
|
||||
"""Securely delete a single file."""
|
||||
if self.path.is_file():
|
||||
@@ -261,28 +260,28 @@ class SecureDeleter:
|
||||
debug.print(f"File deleted: {self.path}")
|
||||
else:
|
||||
debug.print(f"Not a file: {self.path}")
|
||||
|
||||
|
||||
def delete_directory(self) -> None:
|
||||
"""Securely delete a directory and all contents."""
|
||||
if not self.path.is_dir():
|
||||
debug.print(f"Not a directory: {self.path}")
|
||||
return
|
||||
|
||||
|
||||
debug.print(f"Securely deleting directory: {self.path}")
|
||||
|
||||
|
||||
# First, securely overwrite all files
|
||||
file_count = 0
|
||||
for file_path in self.path.rglob('*'):
|
||||
if file_path.is_file():
|
||||
self._overwrite_file(file_path)
|
||||
file_count += 1
|
||||
|
||||
|
||||
debug.print(f"Overwrote {file_count} files")
|
||||
|
||||
|
||||
# Then remove the directory tree
|
||||
shutil.rmtree(self.path)
|
||||
debug.print(f"Directory deleted: {self.path}")
|
||||
|
||||
|
||||
def execute(self) -> None:
|
||||
"""Securely delete the path (file or directory)."""
|
||||
debug.print(f"Executing secure deletion: {self.path}")
|
||||
@@ -294,14 +293,14 @@ class SecureDeleter:
|
||||
debug.print(f"Path does not exist: {self.path}")
|
||||
|
||||
|
||||
def secure_delete(path: Union[str, Path], passes: int = 7) -> None:
|
||||
def secure_delete(path: str | Path, passes: int = 7) -> None:
|
||||
"""
|
||||
Convenience function for secure deletion.
|
||||
|
||||
|
||||
Args:
|
||||
path: Path to file or directory
|
||||
passes: Number of overwrite passes
|
||||
|
||||
|
||||
Example:
|
||||
>>> secure_delete("secret.txt", passes=3)
|
||||
"""
|
||||
@@ -312,19 +311,19 @@ def secure_delete(path: Union[str, Path], passes: int = 7) -> None:
|
||||
def format_file_size(size_bytes: int) -> str:
|
||||
"""
|
||||
Format file size for display.
|
||||
|
||||
|
||||
Args:
|
||||
size_bytes: Size in bytes
|
||||
|
||||
|
||||
Returns:
|
||||
Human-readable string (e.g., "1.5 MB")
|
||||
|
||||
|
||||
Example:
|
||||
>>> format_file_size(1500000)
|
||||
"1.5 MB"
|
||||
"""
|
||||
debug.validate(size_bytes >= 0, f"File size cannot be negative: {size_bytes}")
|
||||
|
||||
|
||||
size: float = float(size_bytes)
|
||||
for unit in ['B', 'KB', 'MB', 'GB']:
|
||||
if size < 1024:
|
||||
@@ -338,13 +337,13 @@ def format_file_size(size_bytes: int) -> str:
|
||||
def format_number(n: int) -> str:
|
||||
"""
|
||||
Format number with commas.
|
||||
|
||||
|
||||
Args:
|
||||
n: Integer to format
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted string
|
||||
|
||||
|
||||
Example:
|
||||
>>> format_number(1234567)
|
||||
"1,234,567"
|
||||
@@ -356,15 +355,15 @@ def format_number(n: int) -> str:
|
||||
def clamp(value: int, min_val: int, max_val: int) -> int:
|
||||
"""
|
||||
Clamp value to range.
|
||||
|
||||
|
||||
Args:
|
||||
value: Value to clamp
|
||||
min_val: Minimum allowed value
|
||||
max_val: Maximum allowed value
|
||||
|
||||
|
||||
Returns:
|
||||
Clamped value
|
||||
|
||||
|
||||
Example:
|
||||
>>> clamp(15, 0, 10)
|
||||
10
|
||||
|
||||
@@ -10,40 +10,50 @@ Changes in v3.2.0:
|
||||
"""
|
||||
|
||||
import io
|
||||
from typing import Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .constants import (
|
||||
MIN_PIN_LENGTH, MAX_PIN_LENGTH,
|
||||
MAX_MESSAGE_SIZE, MAX_FILE_PAYLOAD_SIZE, MAX_IMAGE_PIXELS, MAX_FILE_SIZE,
|
||||
MIN_RSA_BITS, MIN_KEY_PASSWORD_LENGTH,
|
||||
ALLOWED_IMAGE_EXTENSIONS, ALLOWED_KEY_EXTENSIONS,
|
||||
MIN_PASSPHRASE_WORDS, RECOMMENDED_PASSPHRASE_WORDS,
|
||||
EMBED_MODE_LSB, EMBED_MODE_DCT, EMBED_MODE_AUTO,
|
||||
ALLOWED_IMAGE_EXTENSIONS,
|
||||
ALLOWED_KEY_EXTENSIONS,
|
||||
EMBED_MODE_AUTO,
|
||||
EMBED_MODE_DCT,
|
||||
EMBED_MODE_LSB,
|
||||
MAX_FILE_PAYLOAD_SIZE,
|
||||
MAX_FILE_SIZE,
|
||||
MAX_IMAGE_PIXELS,
|
||||
MAX_MESSAGE_SIZE,
|
||||
MAX_PIN_LENGTH,
|
||||
MIN_KEY_PASSWORD_LENGTH,
|
||||
MIN_PASSPHRASE_WORDS,
|
||||
MIN_PIN_LENGTH,
|
||||
MIN_RSA_BITS,
|
||||
RECOMMENDED_PASSPHRASE_WORDS,
|
||||
)
|
||||
from .models import ValidationResult, FilePayload
|
||||
from .exceptions import (
|
||||
ValidationError, PinValidationError, MessageValidationError,
|
||||
ImageValidationError, KeyValidationError, SecurityFactorError,
|
||||
FileTooLargeError, UnsupportedFileTypeError,
|
||||
ImageValidationError,
|
||||
KeyValidationError,
|
||||
MessageValidationError,
|
||||
PinValidationError,
|
||||
SecurityFactorError,
|
||||
)
|
||||
from .keygen import load_rsa_key
|
||||
from .models import FilePayload, ValidationResult
|
||||
|
||||
|
||||
def validate_pin(pin: str, required: bool = False) -> ValidationResult:
|
||||
"""
|
||||
Validate PIN format.
|
||||
|
||||
|
||||
Rules:
|
||||
- 6-9 digits only
|
||||
- Cannot start with zero
|
||||
- Empty is OK if not required
|
||||
|
||||
|
||||
Args:
|
||||
pin: PIN string to validate
|
||||
required: Whether PIN is required
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
@@ -51,83 +61,83 @@ def validate_pin(pin: str, required: bool = False) -> ValidationResult:
|
||||
if required:
|
||||
return ValidationResult.error("PIN is required")
|
||||
return ValidationResult.ok()
|
||||
|
||||
|
||||
if not pin.isdigit():
|
||||
return ValidationResult.error("PIN must contain only digits")
|
||||
|
||||
|
||||
if len(pin) < MIN_PIN_LENGTH or len(pin) > MAX_PIN_LENGTH:
|
||||
return ValidationResult.error(
|
||||
f"PIN must be {MIN_PIN_LENGTH}-{MAX_PIN_LENGTH} digits"
|
||||
)
|
||||
|
||||
|
||||
if pin[0] == '0':
|
||||
return ValidationResult.error("PIN cannot start with zero")
|
||||
|
||||
|
||||
return ValidationResult.ok(length=len(pin))
|
||||
|
||||
|
||||
def validate_message(message: str) -> ValidationResult:
|
||||
"""
|
||||
Validate text message content and size.
|
||||
|
||||
|
||||
Args:
|
||||
message: Message text
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
if not message:
|
||||
return ValidationResult.error("Message is required")
|
||||
|
||||
|
||||
if len(message) > MAX_MESSAGE_SIZE:
|
||||
return ValidationResult.error(
|
||||
f"Message too long ({len(message):,} chars). Maximum: {MAX_MESSAGE_SIZE:,} characters"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(length=len(message))
|
||||
|
||||
|
||||
def validate_payload(payload: Union[str, bytes, FilePayload]) -> ValidationResult:
|
||||
def validate_payload(payload: str | bytes | FilePayload) -> ValidationResult:
|
||||
"""
|
||||
Validate a payload (text message, bytes, or file).
|
||||
|
||||
|
||||
Args:
|
||||
payload: Text string, raw bytes, or FilePayload
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
if isinstance(payload, str):
|
||||
return validate_message(payload)
|
||||
|
||||
|
||||
elif isinstance(payload, FilePayload):
|
||||
if not payload.data:
|
||||
return ValidationResult.error("File is empty")
|
||||
|
||||
|
||||
if len(payload.data) > MAX_FILE_PAYLOAD_SIZE:
|
||||
return ValidationResult.error(
|
||||
f"File too large ({len(payload.data):,} bytes). "
|
||||
f"Maximum: {MAX_FILE_PAYLOAD_SIZE:,} bytes ({MAX_FILE_PAYLOAD_SIZE // 1024} KB)"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(
|
||||
size=len(payload.data),
|
||||
filename=payload.filename,
|
||||
mime_type=payload.mime_type
|
||||
)
|
||||
|
||||
|
||||
elif isinstance(payload, bytes):
|
||||
if not payload:
|
||||
return ValidationResult.error("Payload is empty")
|
||||
|
||||
|
||||
if len(payload) > MAX_FILE_PAYLOAD_SIZE:
|
||||
return ValidationResult.error(
|
||||
f"Payload too large ({len(payload):,} bytes). "
|
||||
f"Maximum: {MAX_FILE_PAYLOAD_SIZE:,} bytes ({MAX_FILE_PAYLOAD_SIZE // 1024} KB)"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(size=len(payload))
|
||||
|
||||
|
||||
else:
|
||||
return ValidationResult.error(f"Invalid payload type: {type(payload)}")
|
||||
|
||||
@@ -139,18 +149,18 @@ def validate_file_payload(
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate a file for embedding.
|
||||
|
||||
|
||||
Args:
|
||||
file_data: Raw file bytes
|
||||
filename: Original filename (for display in errors)
|
||||
max_size: Maximum allowed size in bytes
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
if not file_data:
|
||||
return ValidationResult.error("File is empty")
|
||||
|
||||
|
||||
if len(file_data) > max_size:
|
||||
size_kb = len(file_data) / 1024
|
||||
max_kb = max_size / 1024
|
||||
@@ -158,7 +168,7 @@ def validate_file_payload(
|
||||
f"File '{filename or 'unnamed'}' too large ({size_kb:.1f} KB). "
|
||||
f"Maximum: {max_kb:.0f} KB"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(size=len(file_data), filename=filename)
|
||||
|
||||
|
||||
@@ -169,35 +179,35 @@ def validate_image(
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate image data and dimensions.
|
||||
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
name: Name for error messages
|
||||
check_size: Whether to check pixel dimensions
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult with width, height, pixels
|
||||
"""
|
||||
if not image_data:
|
||||
return ValidationResult.error(f"{name} is required")
|
||||
|
||||
|
||||
if len(image_data) > MAX_FILE_SIZE:
|
||||
return ValidationResult.error(
|
||||
f"{name} too large ({len(image_data):,} bytes). Maximum: {MAX_FILE_SIZE:,} bytes"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
width, height = img.size
|
||||
num_pixels = width * height
|
||||
|
||||
|
||||
if check_size and num_pixels > MAX_IMAGE_PIXELS:
|
||||
max_dim = int(MAX_IMAGE_PIXELS ** 0.5)
|
||||
return ValidationResult.error(
|
||||
f"{name} too large ({width}×{height} = {num_pixels:,} pixels). "
|
||||
f"Maximum: ~{MAX_IMAGE_PIXELS:,} pixels ({max_dim}×{max_dim})"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(
|
||||
width=width,
|
||||
height=height,
|
||||
@@ -205,24 +215,24 @@ def validate_image(
|
||||
mode=img.mode,
|
||||
format=img.format
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return ValidationResult.error(f"Could not read {name}: {e}")
|
||||
|
||||
|
||||
def validate_rsa_key(
|
||||
key_data: bytes,
|
||||
password: Optional[str] = None,
|
||||
password: str | None = None,
|
||||
required: bool = False
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate RSA private key.
|
||||
|
||||
|
||||
Args:
|
||||
key_data: PEM-encoded key bytes
|
||||
password: Password if key is encrypted
|
||||
required: Whether key is required
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult with key_size
|
||||
"""
|
||||
@@ -230,44 +240,44 @@ def validate_rsa_key(
|
||||
if required:
|
||||
return ValidationResult.error("RSA key is required")
|
||||
return ValidationResult.ok()
|
||||
|
||||
|
||||
try:
|
||||
private_key = load_rsa_key(key_data, password)
|
||||
key_size = private_key.key_size
|
||||
|
||||
|
||||
if key_size < MIN_RSA_BITS:
|
||||
return ValidationResult.error(
|
||||
f"RSA key must be at least {MIN_RSA_BITS} bits (got {key_size})"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(key_size=key_size)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return ValidationResult.error(str(e))
|
||||
|
||||
|
||||
def validate_security_factors(
|
||||
pin: str,
|
||||
rsa_key_data: Optional[bytes]
|
||||
rsa_key_data: bytes | None
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate that at least one security factor is provided.
|
||||
|
||||
|
||||
Args:
|
||||
pin: PIN string (may be empty)
|
||||
rsa_key_data: RSA key bytes (may be None/empty)
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
has_pin = bool(pin and pin.strip())
|
||||
has_key = bool(rsa_key_data and len(rsa_key_data) > 0)
|
||||
|
||||
|
||||
if not has_pin and not has_key:
|
||||
return ValidationResult.error(
|
||||
"You must provide at least a PIN or RSA Key"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(has_pin=has_pin, has_key=has_key)
|
||||
|
||||
|
||||
@@ -278,26 +288,26 @@ def validate_file_extension(
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate file extension.
|
||||
|
||||
|
||||
Args:
|
||||
filename: Filename to check
|
||||
allowed: Set of allowed extensions (lowercase, no dot)
|
||||
file_type: Name for error messages
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult with extension
|
||||
"""
|
||||
if not filename or '.' not in filename:
|
||||
return ValidationResult.error(f"{file_type} must have a file extension")
|
||||
|
||||
|
||||
ext = filename.rsplit('.', 1)[1].lower()
|
||||
|
||||
|
||||
if ext not in allowed:
|
||||
return ValidationResult.error(
|
||||
f"Unsupported {file_type.lower()} type: .{ext}. "
|
||||
f"Allowed: {', '.join(sorted('.' + e for e in allowed))}"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(extension=ext)
|
||||
|
||||
|
||||
@@ -314,53 +324,53 @@ def validate_key_file(filename: str) -> ValidationResult:
|
||||
def validate_key_password(password: str) -> ValidationResult:
|
||||
"""
|
||||
Validate password for key encryption.
|
||||
|
||||
|
||||
Args:
|
||||
password: Password string
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
if not password:
|
||||
return ValidationResult.error("Password is required")
|
||||
|
||||
|
||||
if len(password) < MIN_KEY_PASSWORD_LENGTH:
|
||||
return ValidationResult.error(
|
||||
f"Password must be at least {MIN_KEY_PASSWORD_LENGTH} characters"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(length=len(password))
|
||||
|
||||
|
||||
def validate_passphrase(passphrase: str) -> ValidationResult:
|
||||
"""
|
||||
Validate passphrase.
|
||||
|
||||
|
||||
v3.2.0: Recommend 4+ words for good entropy (since date is no longer used).
|
||||
|
||||
|
||||
Args:
|
||||
passphrase: Passphrase string
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult with word_count and optional warning
|
||||
"""
|
||||
if not passphrase or not passphrase.strip():
|
||||
return ValidationResult.error("Passphrase is required")
|
||||
|
||||
|
||||
words = passphrase.strip().split()
|
||||
|
||||
|
||||
if len(words) < MIN_PASSPHRASE_WORDS:
|
||||
return ValidationResult.error(
|
||||
f"Passphrase should have at least {MIN_PASSPHRASE_WORDS} words"
|
||||
)
|
||||
|
||||
|
||||
# Provide warning if below recommended length
|
||||
if len(words) < RECOMMENDED_PASSPHRASE_WORDS:
|
||||
return ValidationResult.ok(
|
||||
word_count=len(words),
|
||||
warning=f"Recommend {RECOMMENDED_PASSPHRASE_WORDS}+ words for better security"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(word_count=len(words))
|
||||
|
||||
|
||||
@@ -381,60 +391,60 @@ def validate_carrier(carrier_data: bytes) -> ValidationResult:
|
||||
def validate_embed_mode(mode: str) -> ValidationResult:
|
||||
"""
|
||||
Validate embedding mode.
|
||||
|
||||
|
||||
Args:
|
||||
mode: Embedding mode string
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
valid_modes = {EMBED_MODE_LSB, EMBED_MODE_DCT, EMBED_MODE_AUTO}
|
||||
|
||||
|
||||
if mode not in valid_modes:
|
||||
return ValidationResult.error(
|
||||
f"Invalid embed_mode: '{mode}'. Valid options: {', '.join(sorted(valid_modes))}"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(mode=mode)
|
||||
|
||||
|
||||
def validate_dct_output_format(format_str: str) -> ValidationResult:
|
||||
"""
|
||||
Validate DCT output format.
|
||||
|
||||
|
||||
Args:
|
||||
format_str: Output format ('png' or 'jpeg')
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
valid_formats = {'png', 'jpeg'}
|
||||
|
||||
|
||||
if format_str.lower() not in valid_formats:
|
||||
return ValidationResult.error(
|
||||
f"Invalid DCT output format: '{format_str}'. Valid options: {', '.join(sorted(valid_formats))}"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(format=format_str.lower())
|
||||
|
||||
|
||||
def validate_dct_color_mode(mode: str) -> ValidationResult:
|
||||
"""
|
||||
Validate DCT color mode.
|
||||
|
||||
|
||||
Args:
|
||||
mode: Color mode ('grayscale' or 'color')
|
||||
|
||||
|
||||
Returns:
|
||||
ValidationResult
|
||||
"""
|
||||
valid_modes = {'grayscale', 'color'}
|
||||
|
||||
|
||||
if mode.lower() not in valid_modes:
|
||||
return ValidationResult.error(
|
||||
f"Invalid DCT color mode: '{mode}'. Valid options: {', '.join(sorted(valid_modes))}"
|
||||
)
|
||||
|
||||
|
||||
return ValidationResult.ok(mode=mode.lower())
|
||||
|
||||
|
||||
@@ -456,7 +466,7 @@ def require_valid_message(message: str) -> None:
|
||||
raise MessageValidationError(result.error_message)
|
||||
|
||||
|
||||
def require_valid_payload(payload: Union[str, bytes, FilePayload]) -> None:
|
||||
def require_valid_payload(payload: str | bytes | FilePayload) -> None:
|
||||
"""Validate payload (text, bytes, or file), raising exception on failure."""
|
||||
result = validate_payload(payload)
|
||||
if not result.is_valid:
|
||||
@@ -472,7 +482,7 @@ def require_valid_image(image_data: bytes, name: str = "Image") -> None:
|
||||
|
||||
def require_valid_rsa_key(
|
||||
key_data: bytes,
|
||||
password: Optional[str] = None,
|
||||
password: str | None = None,
|
||||
required: bool = False
|
||||
) -> None:
|
||||
"""Validate RSA key, raising exception on failure."""
|
||||
@@ -481,7 +491,7 @@ def require_valid_rsa_key(
|
||||
raise KeyValidationError(result.error_message)
|
||||
|
||||
|
||||
def require_security_factors(pin: str, rsa_key_data: Optional[bytes]) -> None:
|
||||
def require_security_factors(pin: str, rsa_key_data: bytes | None) -> None:
|
||||
"""Validate security factors, raising exception on failure."""
|
||||
result = validate_security_factors(pin, rsa_key_data)
|
||||
if not result.is_valid:
|
||||
|
||||
Reference in New Issue
Block a user