Apply black formatter to all Python files

Reformatted 29 files for consistent code style and CI compliance.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee
2026-01-02 17:44:41 -05:00
parent 221678d934
commit afa88bc73b
29 changed files with 2067 additions and 1814 deletions

View File

@@ -31,7 +31,7 @@ from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# Add parent to path for development # Add parent to path for development
sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'src')) sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
from stegasoo import ( from stegasoo import (
MAX_FILE_PAYLOAD_SIZE, MAX_FILE_PAYLOAD_SIZE,
@@ -71,6 +71,7 @@ try:
extract_key_from_qr, extract_key_from_qr,
has_qr_read, has_qr_read,
) )
HAS_QR_READ = has_qr_read() HAS_QR_READ = has_qr_read()
except ImportError: except ImportError:
HAS_QR_READ = False HAS_QR_READ = False
@@ -130,6 +131,7 @@ DctOutputFormatType = Literal["png", "jpeg"]
# MODELS # MODELS
# ============================================================================ # ============================================================================
class GenerateRequest(BaseModel): class GenerateRequest(BaseModel):
use_pin: bool = True use_pin: bool = True
use_rsa: bool = False use_rsa: bool = False
@@ -139,7 +141,7 @@ class GenerateRequest(BaseModel):
default=DEFAULT_PASSPHRASE_WORDS, default=DEFAULT_PASSPHRASE_WORDS,
ge=MIN_PASSPHRASE_WORDS, ge=MIN_PASSPHRASE_WORDS,
le=MAX_PASSPHRASE_WORDS, le=MAX_PASSPHRASE_WORDS,
description="Words per passphrase (v3.2.0: default increased to 4)" description="Words per passphrase (v3.2.0: default increased to 4)",
) )
@@ -150,8 +152,7 @@ class GenerateResponse(BaseModel):
entropy: dict[str, int] entropy: dict[str, int]
# Legacy field for compatibility # Legacy field for compatibility
phrases: dict[str, str] | None = Field( phrases: dict[str, str] | None = Field(
default=None, default=None, description="Deprecated: Use 'passphrase' instead"
description="Deprecated: Use 'passphrase' instead"
) )
@@ -166,24 +167,25 @@ class EncodeRequest(BaseModel):
# Channel key (v4.0.0) # Channel key (v4.0.0)
channel_key: str | None = Field( channel_key: str | None = Field(
default=None, default=None,
description="Channel key for deployment isolation. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key" description="Channel key for deployment isolation. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key",
) )
embed_mode: EmbedModeType = Field( embed_mode: EmbedModeType = Field(
default="lsb", default="lsb",
description="Embedding mode: 'lsb' (default, color) or 'dct' (requires scipy)" description="Embedding mode: 'lsb' (default, color) or 'dct' (requires scipy)",
) )
dct_output_format: DctOutputFormatType = Field( dct_output_format: DctOutputFormatType = Field(
default="png", default="png",
description="DCT output format: 'png' (lossless) or 'jpeg' (smaller). Only applies to DCT mode." description="DCT output format: 'png' (lossless) or 'jpeg' (smaller). Only applies to DCT mode.",
) )
dct_color_mode: DctColorModeType = Field( dct_color_mode: DctColorModeType = Field(
default="grayscale", default="grayscale",
description="DCT color mode: 'grayscale' (default) or 'color' (preserves colors). Only applies to DCT mode." description="DCT color mode: 'grayscale' (default) or 'color' (preserves colors). Only applies to DCT mode.",
) )
class EncodeFileRequest(BaseModel): class EncodeFileRequest(BaseModel):
"""Request for embedding a file (base64-encoded).""" """Request for embedding a file (base64-encoded)."""
file_data_base64: str file_data_base64: str
filename: str filename: str
mime_type: str | None = None mime_type: str | None = None
@@ -196,19 +198,19 @@ class EncodeFileRequest(BaseModel):
# Channel key (v4.0.0) # Channel key (v4.0.0)
channel_key: str | None = Field( channel_key: str | None = Field(
default=None, default=None,
description="Channel key for deployment isolation. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key" description="Channel key for deployment isolation. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key",
) )
embed_mode: EmbedModeType = Field( embed_mode: EmbedModeType = Field(
default="lsb", default="lsb",
description="Embedding mode: 'lsb' (default, color) or 'dct' (requires scipy)" description="Embedding mode: 'lsb' (default, color) or 'dct' (requires scipy)",
) )
dct_output_format: DctOutputFormatType = Field( dct_output_format: DctOutputFormatType = Field(
default="png", default="png",
description="DCT output format: 'png' (lossless) or 'jpeg' (smaller). Only applies to DCT mode." description="DCT output format: 'png' (lossless) or 'jpeg' (smaller). Only applies to DCT mode.",
) )
dct_color_mode: DctColorModeType = Field( dct_color_mode: DctColorModeType = Field(
default="grayscale", default="grayscale",
description="DCT color mode: 'grayscale' (default) or 'color' (preserves colors). Only applies to DCT mode." description="DCT color mode: 'grayscale' (default) or 'color' (preserves colors). Only applies to DCT mode.",
) )
@@ -218,30 +220,23 @@ class EncodeResponse(BaseModel):
capacity_used_percent: float capacity_used_percent: float
embed_mode: str = Field(description="Embedding mode used: 'lsb' or 'dct'") embed_mode: str = Field(description="Embedding mode used: 'lsb' or 'dct'")
output_format: str = Field( output_format: str = Field(
default="png", default="png", description="Output format: 'png' or 'jpeg' (for DCT mode)"
description="Output format: 'png' or 'jpeg' (for DCT mode)"
) )
color_mode: str = Field( color_mode: str = Field(
default="color", default="color",
description="Color mode: 'color' (LSB/DCT color) or 'grayscale' (DCT grayscale)" description="Color mode: 'color' (LSB/DCT color) or 'grayscale' (DCT grayscale)",
) )
# Channel key info (v4.0.0) # Channel key info (v4.0.0)
channel_mode: str = Field( channel_mode: str = Field(default="public", description="Channel mode: 'public' or 'private'")
default="public",
description="Channel mode: 'public' or 'private'"
)
channel_fingerprint: str | None = Field( channel_fingerprint: str | None = Field(
default=None, default=None, description="Channel key fingerprint (if private mode)"
description="Channel key fingerprint (if private mode)"
) )
# Legacy fields (v3.2.0: no longer used in crypto) # Legacy fields (v3.2.0: no longer used in crypto)
date_used: str | None = Field( date_used: str | None = Field(
default=None, default=None, description="Deprecated: Date no longer used in v3.2.0"
description="Deprecated: Date no longer used in v3.2.0"
) )
day_of_week: str | None = Field( day_of_week: str | None = Field(
default=None, default=None, description="Deprecated: Date no longer used in v3.2.0"
description="Deprecated: Date no longer used in v3.2.0"
) )
@@ -255,16 +250,16 @@ class DecodeRequest(BaseModel):
# Channel key (v4.0.0) # Channel key (v4.0.0)
channel_key: str | None = Field( channel_key: str | None = Field(
default=None, default=None,
description="Channel key for decryption. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key" description="Channel key for decryption. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key",
) )
embed_mode: ExtractModeType = Field( embed_mode: ExtractModeType = Field(
default="auto", default="auto", description="Extraction mode: 'auto' (default), 'lsb', or 'dct'"
description="Extraction mode: 'auto' (default), 'lsb', or 'dct'"
) )
class DecodeResponse(BaseModel): class DecodeResponse(BaseModel):
"""Response for decode - can be text or file.""" """Response for decode - can be text or file."""
payload_type: str # 'text' or 'file' payload_type: str # 'text' or 'file'
message: str | None = None # For text message: str | None = None # For text
file_data_base64: str | None = None # For file (base64-encoded) file_data_base64: str | None = None # For file (base64-encoded)
@@ -274,6 +269,7 @@ class DecodeResponse(BaseModel):
class ModeCapacity(BaseModel): class ModeCapacity(BaseModel):
"""Capacity info for a single mode.""" """Capacity info for a single mode."""
capacity_bytes: int capacity_bytes: int
capacity_kb: float capacity_kb: float
available: bool available: bool
@@ -287,22 +283,22 @@ class ImageInfoResponse(BaseModel):
capacity_bytes: int = Field(description="LSB mode capacity (for backwards compatibility)") capacity_bytes: int = Field(description="LSB mode capacity (for backwards compatibility)")
capacity_kb: int = Field(description="LSB mode capacity in KB") capacity_kb: int = Field(description="LSB mode capacity in KB")
modes: dict[str, ModeCapacity] | None = Field( modes: dict[str, ModeCapacity] | None = Field(
default=None, default=None, description="Capacity by embedding mode (v3.0+)"
description="Capacity by embedding mode (v3.0+)"
) )
class CompareModesRequest(BaseModel): class CompareModesRequest(BaseModel):
"""Request for comparing embedding modes.""" """Request for comparing embedding modes."""
carrier_image_base64: str carrier_image_base64: str
payload_size: int | None = Field( payload_size: int | None = Field(
default=None, default=None, description="Optional payload size to check if it fits"
description="Optional payload size to check if it fits"
) )
class CompareModesResponse(BaseModel): class CompareModesResponse(BaseModel):
"""Response comparing LSB and DCT modes.""" """Response comparing LSB and DCT modes."""
width: int width: int
height: int height: int
lsb: dict lsb: dict
@@ -313,6 +309,7 @@ class CompareModesResponse(BaseModel):
class DctModeInfo(BaseModel): class DctModeInfo(BaseModel):
"""Detailed DCT mode information.""" """Detailed DCT mode information."""
available: bool available: bool
name: str name: str
description: str description: str
@@ -324,6 +321,7 @@ class DctModeInfo(BaseModel):
class ChannelStatusResponse(BaseModel): class ChannelStatusResponse(BaseModel):
"""Response for channel key status (v4.0.0).""" """Response for channel key status (v4.0.0)."""
mode: str = Field(description="'public' or 'private'") mode: str = Field(description="'public' or 'private'")
configured: bool = Field(description="Whether a channel key is configured") configured: bool = Field(description="Whether a channel key is configured")
fingerprint: str | None = Field(default=None, description="Key fingerprint (partial)") fingerprint: str | None = Field(default=None, description="Key fingerprint (partial)")
@@ -333,6 +331,7 @@ class ChannelStatusResponse(BaseModel):
class ChannelGenerateResponse(BaseModel): class ChannelGenerateResponse(BaseModel):
"""Response for channel key generation (v4.0.0).""" """Response for channel key generation (v4.0.0)."""
key: str = Field(description="Generated channel key") key: str = Field(description="Generated channel key")
fingerprint: str = Field(description="Key fingerprint") fingerprint: str = Field(description="Key fingerprint")
saved: bool = Field(default=False, description="Whether key was saved to config") saved: bool = Field(default=False, description="Whether key was saved to config")
@@ -341,19 +340,18 @@ class ChannelGenerateResponse(BaseModel):
class ChannelSetRequest(BaseModel): class ChannelSetRequest(BaseModel):
"""Request to set channel key (v4.0.0).""" """Request to set channel key (v4.0.0)."""
key: str = Field(description="Channel key to set") key: str = Field(description="Channel key to set")
location: str = Field(default="user", description="'user' or 'project'") location: str = Field(default="user", description="'user' or 'project'")
class ModesResponse(BaseModel): class ModesResponse(BaseModel):
"""Response showing available embedding modes.""" """Response showing available embedding modes."""
lsb: dict lsb: dict
dct: DctModeInfo dct: DctModeInfo
# Channel key status (v4.0.0) # Channel key status (v4.0.0)
channel: dict | None = Field( channel: dict | None = Field(default=None, description="Channel key status (v4.0.0)")
default=None,
description="Channel key status (v4.0.0)"
)
class StatusResponse(BaseModel): class StatusResponse(BaseModel):
@@ -363,18 +361,10 @@ class StatusResponse(BaseModel):
has_dct: bool has_dct: bool
max_payload_kb: int max_payload_kb: int
available_modes: list[str] available_modes: list[str]
dct_features: dict | None = Field( dct_features: dict | None = Field(default=None, description="DCT mode features (v3.0.1+)")
default=None,
description="DCT mode features (v3.0.1+)"
)
# Channel key status (v4.0.0) # Channel key status (v4.0.0)
channel: dict | None = Field( channel: dict | None = Field(default=None, description="Channel key status (v4.0.0)")
default=None, breaking_changes: dict = Field(description="v4.0.0 breaking changes")
description="Channel key status (v4.0.0)"
)
breaking_changes: dict = Field(
description="v4.0.0 breaking changes"
)
class QrExtractResponse(BaseModel): class QrExtractResponse(BaseModel):
@@ -385,6 +375,7 @@ class QrExtractResponse(BaseModel):
class WillFitRequest(BaseModel): class WillFitRequest(BaseModel):
"""Request to check if payload will fit.""" """Request to check if payload will fit."""
carrier_image_base64: str carrier_image_base64: str
payload_size: int payload_size: int
embed_mode: EmbedModeType = "lsb" embed_mode: EmbedModeType = "lsb"
@@ -392,6 +383,7 @@ class WillFitRequest(BaseModel):
class WillFitResponse(BaseModel): class WillFitResponse(BaseModel):
"""Response for will_fit check.""" """Response for will_fit check."""
fits: bool fits: bool
payload_size: int payload_size: int
capacity: int capacity: int
@@ -409,6 +401,7 @@ class ErrorResponse(BaseModel):
# HELPER: RESOLVE CHANNEL KEY # HELPER: RESOLVE CHANNEL KEY
# ============================================================================ # ============================================================================
def _resolve_channel_key(channel_key: str | None) -> str | None: def _resolve_channel_key(channel_key: str | None) -> str | None:
""" """
Resolve channel key from API parameter. Resolve channel key from API parameter.
@@ -436,8 +429,7 @@ def _resolve_channel_key(channel_key: str | None) -> str | None:
# Explicit key - validate format # Explicit key - validate format
if not validate_channel_key(channel_key): if not validate_channel_key(channel_key):
raise HTTPException( raise HTTPException(
400, 400, "Invalid channel key format. Expected: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
"Invalid channel key format. Expected: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
) )
return channel_key return channel_key
@@ -461,7 +453,7 @@ def _get_channel_info(channel_key: str | None) -> tuple[str, str | None]:
# Auto mode - check server config # Auto mode - check server config
if has_channel_key(): if has_channel_key():
status = get_channel_status() status = get_channel_status()
return "private", status.get('fingerprint') return "private", status.get("fingerprint")
return "public", None return "public", None
@@ -470,6 +462,7 @@ def _get_channel_info(channel_key: str | None) -> tuple[str, str | None]:
# ROUTES - STATUS & INFO # ROUTES - STATUS & INFO
# ============================================================================ # ============================================================================
@app.get("/", response_model=StatusResponse) @app.get("/", response_model=StatusResponse)
async def root(): async def root():
"""Get API status and configuration.""" """Get API status and configuration."""
@@ -488,10 +481,10 @@ async def root():
# Channel key status (v4.0.0) # Channel key status (v4.0.0)
channel_status = get_channel_status() channel_status = get_channel_status()
channel_info = { channel_info = {
"mode": channel_status['mode'], "mode": channel_status["mode"],
"configured": channel_status['configured'], "configured": channel_status["configured"],
"fingerprint": channel_status.get('fingerprint'), "fingerprint": channel_status.get("fingerprint"),
"source": channel_status.get('source'), "source": channel_status.get("source"),
} }
return StatusResponse( return StatusResponse(
@@ -510,8 +503,8 @@ async def root():
"v3_notes": { "v3_notes": {
"date_removed": "No date_str parameter needed - encode/decode anytime", "date_removed": "No date_str parameter needed - encode/decode anytime",
"passphrase_renamed": "day_phrase → passphrase (single passphrase, no daily rotation)", "passphrase_renamed": "day_phrase → passphrase (single passphrase, no daily rotation)",
} },
} },
) )
@@ -525,9 +518,9 @@ async def api_modes():
# Channel status # Channel status
channel_status = get_channel_status() channel_status = get_channel_status()
channel_info = { channel_info = {
"mode": channel_status['mode'], "mode": channel_status["mode"],
"configured": channel_status['configured'], "configured": channel_status["configured"],
"fingerprint": channel_status.get('fingerprint'), "fingerprint": channel_status.get("fingerprint"),
} }
return ModesResponse( return ModesResponse(
@@ -555,6 +548,7 @@ async def api_modes():
# ROUTES - CHANNEL KEY (v4.0.0) # ROUTES - CHANNEL KEY (v4.0.0)
# ============================================================================ # ============================================================================
@app.get("/channel/status", response_model=ChannelStatusResponse) @app.get("/channel/status", response_model=ChannelStatusResponse)
async def api_channel_status( async def api_channel_status(
reveal: bool = Query(False, description="Include full key in response") reveal: bool = Query(False, description="Include full key in response")
@@ -570,11 +564,11 @@ async def api_channel_status(
status = get_channel_status() status = get_channel_status()
return ChannelStatusResponse( return ChannelStatusResponse(
mode=status['mode'], mode=status["mode"],
configured=status['configured'], configured=status["configured"],
fingerprint=status.get('fingerprint'), fingerprint=status.get("fingerprint"),
source=status.get('source'), source=status.get("source"),
key=status.get('key') if reveal and status['configured'] else None, key=status.get("key") if reveal and status["configured"] else None,
) )
@@ -601,11 +595,11 @@ async def api_channel_generate(
save_location = None save_location = None
if save: if save:
set_channel_key(key, location='user') set_channel_key(key, location="user")
saved = True saved = True
save_location = "~/.stegasoo/channel.key" save_location = "~/.stegasoo/channel.key"
elif save_project: elif save_project:
set_channel_key(key, location='project') set_channel_key(key, location="project")
saved = True saved = True
save_location = "./config/channel.key" save_location = "./config/channel.key"
@@ -626,11 +620,10 @@ async def api_channel_set(request: ChannelSetRequest):
""" """
if not validate_channel_key(request.key): if not validate_channel_key(request.key):
raise HTTPException( raise HTTPException(
400, 400, "Invalid channel key format. Expected: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
"Invalid channel key format. Expected: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
) )
if request.location not in ('user', 'project'): if request.location not in ("user", "project"):
raise HTTPException(400, "location must be 'user' or 'project'") raise HTTPException(400, "location must be 'user' or 'project'")
set_channel_key(request.key, location=request.location) set_channel_key(request.key, location=request.location)
@@ -638,8 +631,8 @@ async def api_channel_set(request: ChannelSetRequest):
status = get_channel_status() status = get_channel_status()
return { return {
"success": True, "success": True,
"location": status.get('source'), "location": status.get("source"),
"fingerprint": status.get('fingerprint'), "fingerprint": status.get("fingerprint"),
} }
@@ -655,9 +648,9 @@ async def api_channel_clear(
Note: Does not affect environment variables. Note: Does not affect environment variables.
""" """
if location == "all": if location == "all":
clear_channel_key(location='user') clear_channel_key(location="user")
clear_channel_key(location='project') clear_channel_key(location="project")
elif location in ('user', 'project'): elif location in ("user", "project"):
clear_channel_key(location=location) clear_channel_key(location=location)
else: else:
raise HTTPException(400, "location must be 'user', 'project', or 'all'") raise HTTPException(400, "location must be 'user', 'project', or 'all'")
@@ -665,9 +658,9 @@ async def api_channel_clear(
status = get_channel_status() status = get_channel_status()
return { return {
"success": True, "success": True,
"mode": status['mode'], "mode": status["mode"],
"still_configured": status['configured'], "still_configured": status["configured"],
"remaining_source": status.get('source'), "remaining_source": status.get("source"),
} }
@@ -684,28 +677,30 @@ async def api_compare_modes(request: CompareModesRequest):
comparison = compare_modes(carrier) comparison = compare_modes(carrier)
response = CompareModesResponse( response = CompareModesResponse(
width=comparison['width'], width=comparison["width"],
height=comparison['height'], height=comparison["height"],
lsb={ lsb={
"capacity_bytes": comparison['lsb']['capacity_bytes'], "capacity_bytes": comparison["lsb"]["capacity_bytes"],
"capacity_kb": round(comparison['lsb']['capacity_kb'], 1), "capacity_kb": round(comparison["lsb"]["capacity_kb"], 1),
"available": True, "available": True,
"output_format": comparison['lsb']['output'], "output_format": comparison["lsb"]["output"],
}, },
dct={ dct={
"capacity_bytes": comparison['dct']['capacity_bytes'], "capacity_bytes": comparison["dct"]["capacity_bytes"],
"capacity_kb": round(comparison['dct']['capacity_kb'], 1), "capacity_kb": round(comparison["dct"]["capacity_kb"], 1),
"available": comparison['dct']['available'], "available": comparison["dct"]["available"],
"output_formats": ["png", "jpeg"], "output_formats": ["png", "jpeg"],
"color_modes": ["grayscale", "color"], "color_modes": ["grayscale", "color"],
"ratio_vs_lsb_percent": round(comparison['dct']['ratio_vs_lsb'], 1), "ratio_vs_lsb_percent": round(comparison["dct"]["ratio_vs_lsb"], 1),
}, },
recommendation="lsb" if not comparison['dct']['available'] else "dct for stealth, lsb for capacity" recommendation=(
"lsb" if not comparison["dct"]["available"] else "dct for stealth, lsb for capacity"
),
) )
if request.payload_size: if request.payload_size:
fits_lsb = request.payload_size <= comparison['lsb']['capacity_bytes'] fits_lsb = request.payload_size <= comparison["lsb"]["capacity_bytes"]
fits_dct = request.payload_size <= comparison['dct']['capacity_bytes'] fits_dct = request.payload_size <= comparison["dct"]["capacity_bytes"]
response.payload_check = { response.payload_check = {
"size_bytes": request.payload_size, "size_bytes": request.payload_size,
@@ -714,7 +709,7 @@ async def api_compare_modes(request: CompareModesRequest):
} }
# Update recommendation based on payload # Update recommendation based on payload
if fits_dct and comparison['dct']['available']: if fits_dct and comparison["dct"]["available"]:
response.recommendation = "dct (payload fits, better stealth)" response.recommendation = "dct (payload fits, better stealth)"
elif fits_lsb: elif fits_lsb:
response.recommendation = "lsb (payload too large for dct)" response.recommendation = "lsb (payload too large for dct)"
@@ -743,12 +738,12 @@ async def api_will_fit(request: WillFitRequest):
result = will_fit_by_mode(request.payload_size, carrier, embed_mode=request.embed_mode) result = will_fit_by_mode(request.payload_size, carrier, embed_mode=request.embed_mode)
return WillFitResponse( return WillFitResponse(
fits=result['fits'], fits=result["fits"],
payload_size=result['payload_size'], payload_size=result["payload_size"],
capacity=result['capacity'], capacity=result["capacity"],
usage_percent=round(result['usage_percent'], 1), usage_percent=round(result["usage_percent"], 1),
headroom=result['headroom'], headroom=result["headroom"],
mode=request.embed_mode mode=request.embed_mode,
) )
except HTTPException: except HTTPException:
@@ -761,6 +756,7 @@ async def api_will_fit(request: WillFitRequest):
# ROUTES - QR CODE # ROUTES - QR CODE
# ============================================================================ # ============================================================================
@app.post("/extract-key-from-qr", response_model=QrExtractResponse) @app.post("/extract-key-from-qr", response_model=QrExtractResponse)
async def api_extract_key_from_qr( async def api_extract_key_from_qr(
qr_image: UploadFile = File(..., description="QR code image containing RSA key") qr_image: UploadFile = File(..., description="QR code image containing RSA key")
@@ -772,10 +768,7 @@ async def api_extract_key_from_qr(
Returns the PEM-encoded key if found. Returns the PEM-encoded key if found.
""" """
if not HAS_QR_READ: if not HAS_QR_READ:
raise HTTPException( raise HTTPException(501, "QR code reading not available. Install pyzbar and libzbar.")
501,
"QR code reading not available. Install pyzbar and libzbar."
)
try: try:
image_data = await qr_image.read() image_data = await qr_image.read()
@@ -784,10 +777,7 @@ async def api_extract_key_from_qr(
if key_pem: if key_pem:
return QrExtractResponse(success=True, key_pem=key_pem) return QrExtractResponse(success=True, key_pem=key_pem)
else: else:
return QrExtractResponse( return QrExtractResponse(success=False, error="No valid RSA key found in QR code")
success=False,
error="No valid RSA key found in QR code"
)
except Exception as e: except Exception as e:
return QrExtractResponse(success=False, error=str(e)) return QrExtractResponse(success=False, error=str(e))
@@ -796,6 +786,7 @@ async def api_extract_key_from_qr(
# ROUTES - GENERATE # ROUTES - GENERATE
# ============================================================================ # ============================================================================
@app.post("/generate", response_model=GenerateResponse) @app.post("/generate", response_model=GenerateResponse)
async def api_generate(request: GenerateRequest): async def api_generate(request: GenerateRequest):
""" """
@@ -829,9 +820,9 @@ async def api_generate(request: GenerateRequest):
"passphrase": creds.passphrase_entropy, "passphrase": creds.passphrase_entropy,
"pin": creds.pin_entropy, "pin": creds.pin_entropy,
"rsa": creds.rsa_entropy, "rsa": creds.rsa_entropy,
"total": creds.total_entropy "total": creds.total_entropy,
}, },
phrases=None # Legacy field removed phrases=None, # Legacy field removed
) )
except Exception as e: except Exception as e:
raise HTTPException(500, str(e)) raise HTTPException(500, str(e))
@@ -841,6 +832,7 @@ async def api_generate(request: GenerateRequest):
# HELPER FUNCTION FOR DCT PARAMETERS # HELPER FUNCTION FOR DCT PARAMETERS
# ============================================================================ # ============================================================================
def _get_dct_params(embed_mode: str, dct_output_format: str, dct_color_mode: str) -> dict: def _get_dct_params(embed_mode: str, dct_output_format: str, dct_color_mode: str) -> dict:
""" """
Get DCT-specific parameters if DCT mode is selected. Get DCT-specific parameters if DCT mode is selected.
@@ -876,6 +868,7 @@ def _get_output_info(embed_mode: str, dct_output_format: str, dct_color_mode: st
# ROUTES - ENCODE (JSON) # ROUTES - ENCODE (JSON)
# ============================================================================ # ============================================================================
@app.post("/encode", response_model=EncodeResponse) @app.post("/encode", response_model=EncodeResponse)
async def api_encode(request: EncodeRequest): async def api_encode(request: EncodeRequest):
""" """
@@ -900,9 +893,7 @@ async def api_encode(request: EncodeRequest):
# Get DCT parameters # Get DCT parameters
dct_params = _get_dct_params( dct_params = _get_dct_params(
request.embed_mode, request.embed_mode, request.dct_output_format, request.dct_color_mode
request.dct_output_format,
request.dct_color_mode
) )
# v4.0.0: Include channel_key # v4.0.0: Include channel_key
@@ -919,12 +910,10 @@ async def api_encode(request: EncodeRequest):
**dct_params, **dct_params,
) )
stego_b64 = base64.b64encode(result.stego_image).decode('utf-8') stego_b64 = base64.b64encode(result.stego_image).decode("utf-8")
output_format, color_mode, _ = _get_output_info( output_format, color_mode, _ = _get_output_info(
request.embed_mode, request.embed_mode, request.dct_output_format, request.dct_color_mode
request.dct_output_format,
request.dct_color_mode
) )
# Get channel info for response # Get channel info for response
@@ -975,16 +964,12 @@ async def api_encode_file(request: EncodeFileRequest):
rsa_key = base64.b64decode(request.rsa_key_base64) if request.rsa_key_base64 else None rsa_key = base64.b64decode(request.rsa_key_base64) if request.rsa_key_base64 else None
payload = FilePayload( payload = FilePayload(
data=file_data, data=file_data, filename=request.filename, mime_type=request.mime_type
filename=request.filename,
mime_type=request.mime_type
) )
# Get DCT parameters # Get DCT parameters
dct_params = _get_dct_params( dct_params = _get_dct_params(
request.embed_mode, request.embed_mode, request.dct_output_format, request.dct_color_mode
request.dct_output_format,
request.dct_color_mode
) )
# v4.0.0: Include channel_key # v4.0.0: Include channel_key
@@ -1001,12 +986,10 @@ async def api_encode_file(request: EncodeFileRequest):
**dct_params, **dct_params,
) )
stego_b64 = base64.b64encode(result.stego_image).decode('utf-8') stego_b64 = base64.b64encode(result.stego_image).decode("utf-8")
output_format, color_mode, _ = _get_output_info( output_format, color_mode, _ = _get_output_info(
request.embed_mode, request.embed_mode, request.dct_output_format, request.dct_color_mode
request.dct_output_format,
request.dct_color_mode
) )
# Get channel info for response # Get channel info for response
@@ -1037,6 +1020,7 @@ async def api_encode_file(request: EncodeFileRequest):
# ROUTES - DECODE (JSON) # ROUTES - DECODE (JSON)
# ============================================================================ # ============================================================================
@app.post("/decode", response_model=DecodeResponse) @app.post("/decode", response_model=DecodeResponse)
async def api_decode(request: DecodeRequest): async def api_decode(request: DecodeRequest):
""" """
@@ -1073,21 +1057,18 @@ async def api_decode(request: DecodeRequest):
if result.is_file: if result.is_file:
return DecodeResponse( return DecodeResponse(
payload_type='file', payload_type="file",
file_data_base64=base64.b64encode(result.file_data).decode('utf-8'), file_data_base64=base64.b64encode(result.file_data).decode("utf-8"),
filename=result.filename, filename=result.filename,
mime_type=result.mime_type mime_type=result.mime_type,
) )
else: else:
return DecodeResponse( return DecodeResponse(payload_type="text", message=result.message)
payload_type='text',
message=result.message
)
except DecryptionError as e: except DecryptionError as e:
# Provide helpful error message for channel key issues # Provide helpful error message for channel key issues
error_msg = str(e) error_msg = str(e)
if 'channel key' in error_msg.lower(): if "channel key" in error_msg.lower():
raise HTTPException(401, error_msg) raise HTTPException(401, error_msg)
raise HTTPException(401, "Decryption failed. Check credentials.") raise HTTPException(401, "Decryption failed. Check credentials.")
except StegasooError as e: except StegasooError as e:
@@ -1100,6 +1081,7 @@ async def api_decode(request: DecodeRequest):
# ROUTES - ENCODE/DECODE (MULTIPART) # ROUTES - ENCODE/DECODE (MULTIPART)
# ============================================================================ # ============================================================================
@app.post("/encode/multipart") @app.post("/encode/multipart")
async def api_encode_multipart( async def api_encode_multipart(
passphrase: str = Form(..., description="Passphrase (v3.2.0: renamed from day_phrase)"), passphrase: str = Form(..., description="Passphrase (v3.2.0: renamed from day_phrase)"),
@@ -1112,7 +1094,9 @@ async def api_encode_multipart(
rsa_key_qr: UploadFile | None = File(None), rsa_key_qr: UploadFile | None = File(None),
rsa_password: str = Form(""), rsa_password: str = Form(""),
# Channel key (v4.0.0) # Channel key (v4.0.0)
channel_key: str = Form("auto", description="Channel key: 'auto'=server config, 'none'=public, 'XXXX-...'=explicit"), channel_key: str = Form(
"auto", description="Channel key: 'auto'=server config, 'none'=public, 'XXXX-...'=explicit"
),
embed_mode: str = Form("lsb"), embed_mode: str = Form("lsb"),
dct_output_format: str = Form("png"), dct_output_format: str = Form("png"),
dct_color_mode: str = Form("grayscale"), dct_color_mode: str = Form("grayscale"),
@@ -1162,14 +1146,13 @@ async def api_encode_multipart(
elif rsa_key_qr and rsa_key_qr.filename: elif rsa_key_qr and rsa_key_qr.filename:
if not HAS_QR_READ: if not HAS_QR_READ:
raise HTTPException( raise HTTPException(
501, 501, "QR code reading not available. Install pyzbar and libzbar."
"QR code reading not available. Install pyzbar and libzbar."
) )
qr_image_data = await rsa_key_qr.read() qr_image_data = await rsa_key_qr.read()
key_pem = extract_key_from_qr(qr_image_data) key_pem = extract_key_from_qr(qr_image_data)
if not key_pem: if not key_pem:
raise HTTPException(400, "Could not extract RSA key from QR code image") raise HTTPException(400, "Could not extract RSA key from QR code image")
rsa_key_data = key_pem.encode('utf-8') rsa_key_data = key_pem.encode("utf-8")
rsa_key_from_qr = True rsa_key_from_qr = True
# QR code keys are never password-protected # QR code keys are never password-protected
@@ -1179,9 +1162,7 @@ async def api_encode_multipart(
if payload_file and payload_file.filename: if payload_file and payload_file.filename:
file_data = await payload_file.read() file_data = await payload_file.read()
payload = FilePayload( payload = FilePayload(
data=file_data, data=file_data, filename=payload_file.filename, mime_type=payload_file.content_type
filename=payload_file.filename,
mime_type=payload_file.content_type
) )
elif message: elif message:
payload = message payload = message
@@ -1251,7 +1232,9 @@ async def api_decode_multipart(
rsa_key_qr: UploadFile | None = File(None), rsa_key_qr: UploadFile | None = File(None),
rsa_password: str = Form(""), rsa_password: str = Form(""),
# Channel key (v4.0.0) # Channel key (v4.0.0)
channel_key: str = Form("auto", description="Channel key: 'auto'=server config, 'none'=public, 'XXXX-...'=explicit"), channel_key: str = Form(
"auto", description="Channel key: 'auto'=server config, 'none'=public, 'XXXX-...'=explicit"
),
embed_mode: str = Form("auto"), embed_mode: str = Form("auto"),
): ):
""" """
@@ -1291,14 +1274,13 @@ async def api_decode_multipart(
elif rsa_key_qr and rsa_key_qr.filename: elif rsa_key_qr and rsa_key_qr.filename:
if not HAS_QR_READ: if not HAS_QR_READ:
raise HTTPException( raise HTTPException(
501, 501, "QR code reading not available. Install pyzbar and libzbar."
"QR code reading not available. Install pyzbar and libzbar."
) )
qr_image_data = await rsa_key_qr.read() qr_image_data = await rsa_key_qr.read()
key_pem = extract_key_from_qr(qr_image_data) key_pem = extract_key_from_qr(qr_image_data)
if not key_pem: if not key_pem:
raise HTTPException(400, "Could not extract RSA key from QR code image") raise HTTPException(400, "Could not extract RSA key from QR code image")
rsa_key_data = key_pem.encode('utf-8') rsa_key_data = key_pem.encode("utf-8")
rsa_key_from_qr = True rsa_key_from_qr = True
# QR code keys are never password-protected # QR code keys are never password-protected
@@ -1318,20 +1300,17 @@ async def api_decode_multipart(
if result.is_file: if result.is_file:
return DecodeResponse( return DecodeResponse(
payload_type='file', payload_type="file",
file_data_base64=base64.b64encode(result.file_data).decode('utf-8'), file_data_base64=base64.b64encode(result.file_data).decode("utf-8"),
filename=result.filename, filename=result.filename,
mime_type=result.mime_type mime_type=result.mime_type,
) )
else: else:
return DecodeResponse( return DecodeResponse(payload_type="text", message=result.message)
payload_type='text',
message=result.message
)
except DecryptionError as e: except DecryptionError as e:
error_msg = str(e) error_msg = str(e)
if 'channel key' in error_msg.lower(): if "channel key" in error_msg.lower():
raise HTTPException(401, error_msg) raise HTTPException(401, error_msg)
raise HTTPException(401, "Decryption failed. Check credentials.") raise HTTPException(401, "Decryption failed. Check credentials.")
except StegasooError as e: except StegasooError as e:
@@ -1346,10 +1325,11 @@ async def api_decode_multipart(
# ROUTES - IMAGE INFO # ROUTES - IMAGE INFO
# ============================================================================ # ============================================================================
@app.post("/image/info", response_model=ImageInfoResponse) @app.post("/image/info", response_model=ImageInfoResponse)
async def api_image_info( async def api_image_info(
image: UploadFile = File(...), image: UploadFile = File(...),
include_modes: bool = Query(True, description="Include capacity by mode (v3.0+)") include_modes: bool = Query(True, description="Include capacity by mode (v3.0+)"),
): ):
""" """
Get information about an image's capacity. Get information about an image's capacity.
@@ -1363,29 +1343,29 @@ async def api_image_info(
if not result.is_valid: if not result.is_valid:
raise HTTPException(400, result.error_message) raise HTTPException(400, result.error_message)
capacity = calculate_capacity_by_mode(image_data, 'lsb') capacity = calculate_capacity_by_mode(image_data, "lsb")
response = ImageInfoResponse( response = ImageInfoResponse(
width=result.details['width'], width=result.details["width"],
height=result.details['height'], height=result.details["height"],
pixels=result.details['pixels'], pixels=result.details["pixels"],
capacity_bytes=capacity, capacity_bytes=capacity,
capacity_kb=capacity // 1024 capacity_kb=capacity // 1024,
) )
if include_modes: if include_modes:
comparison = compare_modes(image_data) comparison = compare_modes(image_data)
response.modes = { response.modes = {
"lsb": ModeCapacity( "lsb": ModeCapacity(
capacity_bytes=comparison['lsb']['capacity_bytes'], capacity_bytes=comparison["lsb"]["capacity_bytes"],
capacity_kb=round(comparison['lsb']['capacity_kb'], 1), capacity_kb=round(comparison["lsb"]["capacity_kb"], 1),
available=True, available=True,
output_format=comparison['lsb']['output'], output_format=comparison["lsb"]["output"],
), ),
"dct": ModeCapacity( "dct": ModeCapacity(
capacity_bytes=comparison['dct']['capacity_bytes'], capacity_bytes=comparison["dct"]["capacity_bytes"],
capacity_kb=round(comparison['dct']['capacity_kb'], 1), capacity_kb=round(comparison["dct"]["capacity_kb"], 1),
available=comparison['dct']['available'], available=comparison["dct"]["available"],
output_format="PNG/JPEG (grayscale or color)", output_format="PNG/JPEG (grayscale or color)",
), ),
} }
@@ -1402,18 +1382,17 @@ async def api_image_info(
# ERROR HANDLERS # ERROR HANDLERS
# ============================================================================ # ============================================================================
@app.exception_handler(StegasooError) @app.exception_handler(StegasooError)
async def stegasoo_error_handler(request, exc): async def stegasoo_error_handler(request, exc):
return JSONResponse( return JSONResponse(status_code=400, content={"error": type(exc).__name__, "detail": str(exc)})
status_code=400,
content={"error": type(exc).__name__, "detail": str(exc)}
)
# ============================================================================ # ============================================================================
# MAIN # MAIN
# ============================================================================ # ============================================================================
if __name__ == '__main__': if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="0.0.0.0", port=8000)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,7 @@ import traceback
from pathlib import Path from pathlib import Path
# Ensure stegasoo is importable # Ensure stegasoo is importable
sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'src')) sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
sys.path.insert(0, str(Path(__file__).parent)) sys.path.insert(0, str(Path(__file__).parent))
@@ -66,7 +66,7 @@ def _get_channel_info(resolved_key):
# Auto mode - check server config # Auto mode - check server config
if has_channel_key(): if has_channel_key():
status = get_channel_status() status = get_channel_status()
return "private", status.get('fingerprint') return "private", status.get("fingerprint")
return "public", None return "public", None
@@ -76,62 +76,62 @@ def encode_operation(params: dict) -> dict:
from stegasoo import FilePayload, encode from stegasoo import FilePayload, encode
# Decode base64 inputs # Decode base64 inputs
carrier_data = base64.b64decode(params['carrier_b64']) carrier_data = base64.b64decode(params["carrier_b64"])
reference_data = base64.b64decode(params['reference_b64']) reference_data = base64.b64decode(params["reference_b64"])
# Optional RSA key # Optional RSA key
rsa_key_data = None rsa_key_data = None
if params.get('rsa_key_b64'): if params.get("rsa_key_b64"):
rsa_key_data = base64.b64decode(params['rsa_key_b64']) rsa_key_data = base64.b64decode(params["rsa_key_b64"])
# Determine payload type # Determine payload type
if params.get('file_b64'): if params.get("file_b64"):
file_data = base64.b64decode(params['file_b64']) file_data = base64.b64decode(params["file_b64"])
payload = FilePayload( payload = FilePayload(
data=file_data, data=file_data,
filename=params.get('file_name', 'file'), filename=params.get("file_name", "file"),
mime_type=params.get('file_mime', 'application/octet-stream'), mime_type=params.get("file_mime", "application/octet-stream"),
) )
else: else:
payload = params.get('message', '') payload = params.get("message", "")
# Resolve channel key (v4.0.0) # Resolve channel key (v4.0.0)
resolved_channel_key = _resolve_channel_key(params.get('channel_key', 'auto')) resolved_channel_key = _resolve_channel_key(params.get("channel_key", "auto"))
# Call encode with correct parameter names # Call encode with correct parameter names
result = encode( result = encode(
message=payload, message=payload,
reference_photo=reference_data, reference_photo=reference_data,
carrier_image=carrier_data, carrier_image=carrier_data,
passphrase=params.get('passphrase', ''), passphrase=params.get("passphrase", ""),
pin=params.get('pin'), pin=params.get("pin"),
rsa_key_data=rsa_key_data, rsa_key_data=rsa_key_data,
rsa_password=params.get('rsa_password'), rsa_password=params.get("rsa_password"),
embed_mode=params.get('embed_mode', 'lsb'), embed_mode=params.get("embed_mode", "lsb"),
dct_output_format=params.get('dct_output_format', 'png'), dct_output_format=params.get("dct_output_format", "png"),
dct_color_mode=params.get('dct_color_mode', 'color'), dct_color_mode=params.get("dct_color_mode", "color"),
channel_key=resolved_channel_key, # v4.0.0 channel_key=resolved_channel_key, # v4.0.0
) )
# Build stats dict if available # Build stats dict if available
stats = None stats = None
if hasattr(result, 'stats') and result.stats: if hasattr(result, "stats") and result.stats:
stats = { stats = {
'pixels_modified': getattr(result.stats, 'pixels_modified', 0), "pixels_modified": getattr(result.stats, "pixels_modified", 0),
'capacity_used': getattr(result.stats, 'capacity_used', 0), "capacity_used": getattr(result.stats, "capacity_used", 0),
'bytes_embedded': getattr(result.stats, 'bytes_embedded', 0), "bytes_embedded": getattr(result.stats, "bytes_embedded", 0),
} }
# Get channel info for response (v4.0.0) # Get channel info for response (v4.0.0)
channel_mode, channel_fingerprint = _get_channel_info(resolved_channel_key) channel_mode, channel_fingerprint = _get_channel_info(resolved_channel_key)
return { return {
'success': True, "success": True,
'stego_b64': base64.b64encode(result.stego_image).decode('ascii'), "stego_b64": base64.b64encode(result.stego_image).decode("ascii"),
'filename': getattr(result, 'filename', None), "filename": getattr(result, "filename", None),
'stats': stats, "stats": stats,
'channel_mode': channel_mode, "channel_mode": channel_mode,
'channel_fingerprint': channel_fingerprint, "channel_fingerprint": channel_fingerprint,
} }
@@ -140,42 +140,42 @@ def decode_operation(params: dict) -> dict:
from stegasoo import decode from stegasoo import decode
# Decode base64 inputs # Decode base64 inputs
stego_data = base64.b64decode(params['stego_b64']) stego_data = base64.b64decode(params["stego_b64"])
reference_data = base64.b64decode(params['reference_b64']) reference_data = base64.b64decode(params["reference_b64"])
# Optional RSA key # Optional RSA key
rsa_key_data = None rsa_key_data = None
if params.get('rsa_key_b64'): if params.get("rsa_key_b64"):
rsa_key_data = base64.b64decode(params['rsa_key_b64']) rsa_key_data = base64.b64decode(params["rsa_key_b64"])
# Resolve channel key (v4.0.0) # Resolve channel key (v4.0.0)
resolved_channel_key = _resolve_channel_key(params.get('channel_key', 'auto')) resolved_channel_key = _resolve_channel_key(params.get("channel_key", "auto"))
# Call decode with correct parameter names # Call decode with correct parameter names
result = decode( result = decode(
stego_image=stego_data, stego_image=stego_data,
reference_photo=reference_data, reference_photo=reference_data,
passphrase=params.get('passphrase', ''), passphrase=params.get("passphrase", ""),
pin=params.get('pin'), pin=params.get("pin"),
rsa_key_data=rsa_key_data, rsa_key_data=rsa_key_data,
rsa_password=params.get('rsa_password'), rsa_password=params.get("rsa_password"),
embed_mode=params.get('embed_mode', 'auto'), embed_mode=params.get("embed_mode", "auto"),
channel_key=resolved_channel_key, # v4.0.0 channel_key=resolved_channel_key, # v4.0.0
) )
if result.is_file: if result.is_file:
return { return {
'success': True, "success": True,
'is_file': True, "is_file": True,
'file_b64': base64.b64encode(result.file_data).decode('ascii'), "file_b64": base64.b64encode(result.file_data).decode("ascii"),
'filename': result.filename, "filename": result.filename,
'mime_type': result.mime_type, "mime_type": result.mime_type,
} }
else: else:
return { return {
'success': True, "success": True,
'is_file': False, "is_file": False,
'message': result.message, "message": result.message,
} }
@@ -183,12 +183,12 @@ def compare_operation(params: dict) -> dict:
"""Handle compare_modes operation.""" """Handle compare_modes operation."""
from stegasoo import compare_modes from stegasoo import compare_modes
carrier_data = base64.b64decode(params['carrier_b64']) carrier_data = base64.b64decode(params["carrier_b64"])
result = compare_modes(carrier_data) result = compare_modes(carrier_data)
return { return {
'success': True, "success": True,
'comparison': result, "comparison": result,
} }
@@ -196,17 +196,17 @@ def capacity_check_operation(params: dict) -> dict:
"""Handle will_fit_by_mode operation.""" """Handle will_fit_by_mode operation."""
from stegasoo import will_fit_by_mode from stegasoo import will_fit_by_mode
carrier_data = base64.b64decode(params['carrier_b64']) carrier_data = base64.b64decode(params["carrier_b64"])
result = will_fit_by_mode( result = will_fit_by_mode(
payload=params['payload_size'], payload=params["payload_size"],
carrier_image=carrier_data, carrier_image=carrier_data,
embed_mode=params.get('embed_mode', 'lsb'), embed_mode=params.get("embed_mode", "lsb"),
) )
return { return {
'success': True, "success": True,
'result': result, "result": result,
} }
@@ -215,17 +215,17 @@ def channel_status_operation(params: dict) -> dict:
from stegasoo import get_channel_status from stegasoo import get_channel_status
status = get_channel_status() status = get_channel_status()
reveal = params.get('reveal', False) reveal = params.get("reveal", False)
return { return {
'success': True, "success": True,
'status': { "status": {
'mode': status['mode'], "mode": status["mode"],
'configured': status['configured'], "configured": status["configured"],
'fingerprint': status.get('fingerprint'), "fingerprint": status.get("fingerprint"),
'source': status.get('source'), "source": status.get("source"),
'key': status.get('key') if reveal and status['configured'] else None, "key": status.get("key") if reveal and status["configured"] else None,
} },
} }
@@ -236,37 +236,37 @@ def main():
input_text = sys.stdin.read() input_text = sys.stdin.read()
if not input_text.strip(): if not input_text.strip():
output = {'success': False, 'error': 'No input provided'} output = {"success": False, "error": "No input provided"}
else: else:
params = json.loads(input_text) params = json.loads(input_text)
operation = params.get('operation') operation = params.get("operation")
if operation == 'encode': if operation == "encode":
output = encode_operation(params) output = encode_operation(params)
elif operation == 'decode': elif operation == "decode":
output = decode_operation(params) output = decode_operation(params)
elif operation == 'compare': elif operation == "compare":
output = compare_operation(params) output = compare_operation(params)
elif operation == 'capacity': elif operation == "capacity":
output = capacity_check_operation(params) output = capacity_check_operation(params)
elif operation == 'channel_status': elif operation == "channel_status":
output = channel_status_operation(params) output = channel_status_operation(params)
else: else:
output = {'success': False, 'error': f'Unknown operation: {operation}'} output = {"success": False, "error": f"Unknown operation: {operation}"}
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
output = {'success': False, 'error': f'Invalid JSON: {e}'} output = {"success": False, "error": f"Invalid JSON: {e}"}
except Exception as e: except Exception as e:
output = { output = {
'success': False, "success": False,
'error': str(e), "error": str(e),
'error_type': type(e).__name__, "error_type": type(e).__name__,
'traceback': traceback.format_exc(), "traceback": traceback.format_exc(),
} }
# Write output as JSON # Write output as JSON
print(json.dumps(output), flush=True) print(json.dumps(output), flush=True)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@@ -55,12 +55,13 @@ from typing import Any
DEFAULT_TIMEOUT = 120 DEFAULT_TIMEOUT = 120
# Path to worker script - adjust if needed # Path to worker script - adjust if needed
WORKER_SCRIPT = Path(__file__).parent / 'stego_worker.py' WORKER_SCRIPT = Path(__file__).parent / "stego_worker.py"
@dataclass @dataclass
class EncodeResult: class EncodeResult:
"""Result from encode operation.""" """Result from encode operation."""
success: bool success: bool
stego_data: bytes | None = None stego_data: bytes | None = None
filename: str | None = None filename: str | None = None
@@ -75,6 +76,7 @@ class EncodeResult:
@dataclass @dataclass
class DecodeResult: class DecodeResult:
"""Result from decode operation.""" """Result from decode operation."""
success: bool success: bool
is_file: bool = False is_file: bool = False
message: str | None = None message: str | None = None
@@ -88,6 +90,7 @@ class DecodeResult:
@dataclass @dataclass
class CompareResult: class CompareResult:
"""Result from compare_modes operation.""" """Result from compare_modes operation."""
success: bool success: bool
width: int = 0 width: int = 0
height: int = 0 height: int = 0
@@ -99,6 +102,7 @@ class CompareResult:
@dataclass @dataclass
class CapacityResult: class CapacityResult:
"""Result from capacity check operation.""" """Result from capacity check operation."""
success: bool success: bool
fits: bool = False fits: bool = False
payload_size: int = 0 payload_size: int = 0
@@ -112,6 +116,7 @@ class CapacityResult:
@dataclass @dataclass
class ChannelStatusResult: class ChannelStatusResult:
"""Result from channel status check (v4.0.0).""" """Result from channel status check (v4.0.0)."""
success: bool success: bool
mode: str = "public" mode: str = "public"
configured: bool = False configured: bool = False
@@ -177,37 +182,37 @@ class SubprocessStego:
if result.returncode != 0: if result.returncode != 0:
# Worker crashed # Worker crashed
return { return {
'success': False, "success": False,
'error': f'Worker crashed (exit code {result.returncode})', "error": f"Worker crashed (exit code {result.returncode})",
'stderr': result.stderr, "stderr": result.stderr,
} }
if not result.stdout.strip(): if not result.stdout.strip():
return { return {
'success': False, "success": False,
'error': 'Worker returned empty output', "error": "Worker returned empty output",
'stderr': result.stderr, "stderr": result.stderr,
} }
return json.loads(result.stdout) return json.loads(result.stdout)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
return { return {
'success': False, "success": False,
'error': f'Operation timed out after {timeout} seconds', "error": f"Operation timed out after {timeout} seconds",
'error_type': 'TimeoutError', "error_type": "TimeoutError",
} }
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
return { return {
'success': False, "success": False,
'error': f'Invalid JSON from worker: {e}', "error": f"Invalid JSON from worker: {e}",
'raw_output': result.stdout if 'result' in dir() else None, "raw_output": result.stdout if "result" in dir() else None,
} }
except Exception as e: except Exception as e:
return { return {
'success': False, "success": False,
'error': str(e), "error": str(e),
'error_type': type(e).__name__, "error_type": type(e).__name__,
} }
def encode( def encode(
@@ -253,43 +258,43 @@ class SubprocessStego:
EncodeResult with stego_data and extension on success EncodeResult with stego_data and extension on success
""" """
params = { params = {
'operation': 'encode', "operation": "encode",
'carrier_b64': base64.b64encode(carrier_data).decode('ascii'), "carrier_b64": base64.b64encode(carrier_data).decode("ascii"),
'reference_b64': base64.b64encode(reference_data).decode('ascii'), "reference_b64": base64.b64encode(reference_data).decode("ascii"),
'message': message, "message": message,
'passphrase': passphrase, "passphrase": passphrase,
'pin': pin, "pin": pin,
'embed_mode': embed_mode, "embed_mode": embed_mode,
'dct_output_format': dct_output_format, "dct_output_format": dct_output_format,
'dct_color_mode': dct_color_mode, "dct_color_mode": dct_color_mode,
'channel_key': channel_key, # v4.0.0 "channel_key": channel_key, # v4.0.0
} }
if file_data: if file_data:
params['file_b64'] = base64.b64encode(file_data).decode('ascii') params["file_b64"] = base64.b64encode(file_data).decode("ascii")
params['file_name'] = file_name params["file_name"] = file_name
params['file_mime'] = file_mime params["file_mime"] = file_mime
if rsa_key_data: if rsa_key_data:
params['rsa_key_b64'] = base64.b64encode(rsa_key_data).decode('ascii') params["rsa_key_b64"] = base64.b64encode(rsa_key_data).decode("ascii")
params['rsa_password'] = rsa_password params["rsa_password"] = rsa_password
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get("success"):
return EncodeResult( return EncodeResult(
success=True, success=True,
stego_data=base64.b64decode(result['stego_b64']), stego_data=base64.b64decode(result["stego_b64"]),
filename=result.get('filename'), filename=result.get("filename"),
stats=result.get('stats'), stats=result.get("stats"),
channel_mode=result.get('channel_mode'), channel_mode=result.get("channel_mode"),
channel_fingerprint=result.get('channel_fingerprint'), channel_fingerprint=result.get("channel_fingerprint"),
) )
else: else:
return EncodeResult( return EncodeResult(
success=False, success=False,
error=result.get('error', 'Unknown error'), error=result.get("error", "Unknown error"),
error_type=result.get('error_type'), error_type=result.get("error_type"),
) )
def decode( def decode(
@@ -323,41 +328,41 @@ class SubprocessStego:
DecodeResult with message or file_data on success DecodeResult with message or file_data on success
""" """
params = { params = {
'operation': 'decode', "operation": "decode",
'stego_b64': base64.b64encode(stego_data).decode('ascii'), "stego_b64": base64.b64encode(stego_data).decode("ascii"),
'reference_b64': base64.b64encode(reference_data).decode('ascii'), "reference_b64": base64.b64encode(reference_data).decode("ascii"),
'passphrase': passphrase, "passphrase": passphrase,
'pin': pin, "pin": pin,
'embed_mode': embed_mode, "embed_mode": embed_mode,
'channel_key': channel_key, # v4.0.0 "channel_key": channel_key, # v4.0.0
} }
if rsa_key_data: if rsa_key_data:
params['rsa_key_b64'] = base64.b64encode(rsa_key_data).decode('ascii') params["rsa_key_b64"] = base64.b64encode(rsa_key_data).decode("ascii")
params['rsa_password'] = rsa_password params["rsa_password"] = rsa_password
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get("success"):
if result.get('is_file'): if result.get("is_file"):
return DecodeResult( return DecodeResult(
success=True, success=True,
is_file=True, is_file=True,
file_data=base64.b64decode(result['file_b64']), file_data=base64.b64decode(result["file_b64"]),
filename=result.get('filename'), filename=result.get("filename"),
mime_type=result.get('mime_type'), mime_type=result.get("mime_type"),
) )
else: else:
return DecodeResult( return DecodeResult(
success=True, success=True,
is_file=False, is_file=False,
message=result.get('message'), message=result.get("message"),
) )
else: else:
return DecodeResult( return DecodeResult(
success=False, success=False,
error=result.get('error', 'Unknown error'), error=result.get("error", "Unknown error"),
error_type=result.get('error_type'), error_type=result.get("error_type"),
) )
def compare_modes( def compare_modes(
@@ -376,25 +381,25 @@ class SubprocessStego:
CompareResult with capacity information CompareResult with capacity information
""" """
params = { params = {
'operation': 'compare', "operation": "compare",
'carrier_b64': base64.b64encode(carrier_data).decode('ascii'), "carrier_b64": base64.b64encode(carrier_data).decode("ascii"),
} }
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get("success"):
comparison = result.get('comparison', {}) comparison = result.get("comparison", {})
return CompareResult( return CompareResult(
success=True, success=True,
width=comparison.get('width', 0), width=comparison.get("width", 0),
height=comparison.get('height', 0), height=comparison.get("height", 0),
lsb=comparison.get('lsb'), lsb=comparison.get("lsb"),
dct=comparison.get('dct'), dct=comparison.get("dct"),
) )
else: else:
return CompareResult( return CompareResult(
success=False, success=False,
error=result.get('error', 'Unknown error'), error=result.get("error", "Unknown error"),
) )
def check_capacity( def check_capacity(
@@ -417,29 +422,29 @@ class SubprocessStego:
CapacityResult with fit information CapacityResult with fit information
""" """
params = { params = {
'operation': 'capacity', "operation": "capacity",
'carrier_b64': base64.b64encode(carrier_data).decode('ascii'), "carrier_b64": base64.b64encode(carrier_data).decode("ascii"),
'payload_size': payload_size, "payload_size": payload_size,
'embed_mode': embed_mode, "embed_mode": embed_mode,
} }
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get("success"):
r = result.get('result', {}) r = result.get("result", {})
return CapacityResult( return CapacityResult(
success=True, success=True,
fits=r.get('fits', False), fits=r.get("fits", False),
payload_size=r.get('payload_size', 0), payload_size=r.get("payload_size", 0),
capacity=r.get('capacity', 0), capacity=r.get("capacity", 0),
usage_percent=r.get('usage_percent', 0.0), usage_percent=r.get("usage_percent", 0.0),
headroom=r.get('headroom', 0), headroom=r.get("headroom", 0),
mode=r.get('mode', embed_mode), mode=r.get("mode", embed_mode),
) )
else: else:
return CapacityResult( return CapacityResult(
success=False, success=False,
error=result.get('error', 'Unknown error'), error=result.get("error", "Unknown error"),
) )
def get_channel_status( def get_channel_status(
@@ -458,26 +463,26 @@ class SubprocessStego:
ChannelStatusResult with channel info ChannelStatusResult with channel info
""" """
params = { params = {
'operation': 'channel_status', "operation": "channel_status",
'reveal': reveal, "reveal": reveal,
} }
result = self._run_worker(params, timeout) result = self._run_worker(params, timeout)
if result.get('success'): if result.get("success"):
status = result.get('status', {}) status = result.get("status", {})
return ChannelStatusResult( return ChannelStatusResult(
success=True, success=True,
mode=status.get('mode', 'public'), mode=status.get("mode", "public"),
configured=status.get('configured', False), configured=status.get("configured", False),
fingerprint=status.get('fingerprint'), fingerprint=status.get("fingerprint"),
source=status.get('source'), source=status.get("source"),
key=status.get('key') if reveal else None, key=status.get("key") if reveal else None,
) )
else: else:
return ChannelStatusResult( return ChannelStatusResult(
success=False, success=False,
error=result.get('error', 'Unknown error'), error=result.get("error", "Unknown error"),
) )

View File

@@ -22,6 +22,7 @@ def main():
""" """
try: try:
from stegasoo.cli import main as cli_main from stegasoo.cli import main as cli_main
cli_main() cli_main()
except ImportError as e: except ImportError as e:
# Provide helpful error if dependencies are missing # Provide helpful error if dependencies are missing
@@ -43,6 +44,7 @@ def version():
"""Print version and exit.""" """Print version and exit."""
try: try:
from stegasoo import __version__ from stegasoo import __version__
print(f"stegasoo {__version__}") print(f"stegasoo {__version__}")
except ImportError: except ImportError:
print("stegasoo (version unknown)") print("stegasoo (version unknown)")

View File

@@ -60,6 +60,7 @@ try:
extract_key_from_qr, extract_key_from_qr,
generate_qr_code, generate_qr_code,
) )
HAS_QR_UTILS = True HAS_QR_UTILS = True
except ImportError: except ImportError:
HAS_QR_UTILS = False HAS_QR_UTILS = False
@@ -151,13 +152,11 @@ DCT_BYTES_PER_PIXEL = 0.125
__all__ = [ __all__ = [
# Version # Version
"__version__", "__version__",
# Core # Core
"encode", "encode",
"decode", "decode",
"decode_file", "decode_file",
"decode_text", "decode_text",
# Generation # Generation
"generate_pin", "generate_pin",
"generate_passphrase", "generate_passphrase",
@@ -165,7 +164,6 @@ __all__ = [
"generate_credentials", "generate_credentials",
"export_rsa_key_pem", "export_rsa_key_pem",
"load_rsa_key", "load_rsa_key",
# Channel key management (v4.0.0) # Channel key management (v4.0.0)
"generate_channel_key", "generate_channel_key",
"get_channel_key", "get_channel_key",
@@ -177,28 +175,22 @@ __all__ = [
"format_channel_key", "format_channel_key",
"get_active_channel_key", "get_active_channel_key",
"get_channel_fingerprint", "get_channel_fingerprint",
# Image utilities # Image utilities
"get_image_info", "get_image_info",
"compare_capacity", "compare_capacity",
# Utilities # Utilities
"generate_filename", "generate_filename",
# Crypto # Crypto
"has_argon2", "has_argon2",
# Steganography # Steganography
"has_dct_support", "has_dct_support",
"compare_modes", "compare_modes",
"will_fit_by_mode", "will_fit_by_mode",
# QR utilities # QR utilities
"generate_qr_code", "generate_qr_code",
"extract_key_from_qr", "extract_key_from_qr",
"detect_and_crop_qr", "detect_and_crop_qr",
"HAS_QR_UTILS", "HAS_QR_UTILS",
# Validation # Validation
"validate_reference_photo", "validate_reference_photo",
"validate_carrier", "validate_carrier",
@@ -212,7 +204,6 @@ __all__ = [
"validate_dct_output_format", "validate_dct_output_format",
"validate_dct_color_mode", "validate_dct_color_mode",
"validate_channel_key", "validate_channel_key",
# Models # Models
"ImageInfo", "ImageInfo",
"CapacityComparison", "CapacityComparison",
@@ -222,7 +213,6 @@ __all__ = [
"FilePayload", "FilePayload",
"Credentials", "Credentials",
"ValidationResult", "ValidationResult",
# Exceptions # Exceptions
"StegasooError", "StegasooError",
"ValidationError", "ValidationError",
@@ -242,7 +232,6 @@ __all__ = [
"ExtractionError", "ExtractionError",
"EmbeddingError", "EmbeddingError",
"InvalidHeaderError", "InvalidHeaderError",
# Constants # Constants
"FORMAT_VERSION", "FORMAT_VERSION",
"MIN_PASSPHRASE_WORDS", "MIN_PASSPHRASE_WORDS",

View File

@@ -23,6 +23,7 @@ from .constants import ALLOWED_IMAGE_EXTENSIONS, LOSSLESS_FORMATS
class BatchStatus(Enum): class BatchStatus(Enum):
"""Status of individual batch items.""" """Status of individual batch items."""
PENDING = "pending" PENDING = "pending"
PROCESSING = "processing" PROCESSING = "processing"
SUCCESS = "success" SUCCESS = "success"
@@ -33,6 +34,7 @@ class BatchStatus(Enum):
@dataclass @dataclass
class BatchItem: class BatchItem:
"""Represents a single item in a batch operation.""" """Represents a single item in a batch operation."""
input_path: Path input_path: Path
output_path: Path | None = None output_path: Path | None = None
status: BatchStatus = BatchStatus.PENDING status: BatchStatus = BatchStatus.PENDING
@@ -84,6 +86,7 @@ class BatchCredentials:
) )
result = processor.batch_encode(images, creds, message="secret") result = processor.batch_encode(images, creds, message="secret")
""" """
reference_photo: bytes reference_photo: bytes
passphrase: str # v3.2.0: renamed from day_phrase passphrase: str # v3.2.0: renamed from day_phrase
pin: str = "" pin: str = ""
@@ -101,27 +104,28 @@ class BatchCredentials:
} }
@classmethod @classmethod
def from_dict(cls, data: dict) -> 'BatchCredentials': def from_dict(cls, data: dict) -> "BatchCredentials":
""" """
Create BatchCredentials from a dictionary. Create BatchCredentials from a dictionary.
Handles both v3.2.0 format (passphrase) and legacy formats (day_phrase, phrase). Handles both v3.2.0 format (passphrase) and legacy formats (day_phrase, phrase).
""" """
# Handle legacy 'day_phrase' and 'phrase' keys # Handle legacy 'day_phrase' and 'phrase' keys
passphrase = data.get('passphrase') or data.get('day_phrase') or data.get('phrase', '') passphrase = data.get("passphrase") or data.get("day_phrase") or data.get("phrase", "")
return cls( return cls(
reference_photo=data['reference_photo'], reference_photo=data["reference_photo"],
passphrase=passphrase, passphrase=passphrase,
pin=data.get('pin', ''), pin=data.get("pin", ""),
rsa_key_data=data.get('rsa_key_data'), rsa_key_data=data.get("rsa_key_data"),
rsa_password=data.get('rsa_password'), rsa_password=data.get("rsa_password"),
) )
@dataclass @dataclass
class BatchResult: class BatchResult:
"""Summary of a batch operation.""" """Summary of a batch operation."""
operation: str operation: str
total: int = 0 total: int = 0
succeeded: int = 0 succeeded: int = 0
@@ -232,18 +236,17 @@ class BatchProcessor:
yield path yield path
elif path.is_dir(): elif path.is_dir():
pattern = '**/*' if recursive else '*' pattern = "**/*" if recursive else "*"
for file_path in path.glob(pattern): for file_path in path.glob(pattern):
if file_path.is_file() and self._is_valid_image(file_path): if file_path.is_file() and self._is_valid_image(file_path):
yield file_path yield file_path
def _is_valid_image(self, path: Path) -> bool: def _is_valid_image(self, path: Path) -> bool:
"""Check if path is a valid image file.""" """Check if path is a valid image file."""
return path.suffix.lower().lstrip('.') in ALLOWED_IMAGE_EXTENSIONS return path.suffix.lower().lstrip(".") in ALLOWED_IMAGE_EXTENSIONS
def _normalize_credentials( def _normalize_credentials(
self, self, credentials: dict | BatchCredentials | None
credentials: dict | BatchCredentials | None
) -> BatchCredentials: ) -> BatchCredentials:
""" """
Normalize credentials to BatchCredentials object. Normalize credentials to BatchCredentials object.
@@ -341,7 +344,11 @@ class BatchProcessor:
self._do_encode(item, message, file_payload, creds, compress) self._do_encode(item, message, file_payload, creds, compress)
item.status = BatchStatus.SUCCESS item.status = BatchStatus.SUCCESS
item.output_size = item.output_path.stat().st_size if item.output_path and item.output_path.exists() else 0 item.output_size = (
item.output_path.stat().st_size
if item.output_path and item.output_path.exists()
else 0
)
item.message = f"Encoded to {item.output_path.name}" item.message = f"Encoded to {item.output_path.name}"
except Exception as e: except Exception as e:
@@ -412,7 +419,9 @@ class BatchProcessor:
output_dir=item.output_path, output_dir=item.output_path,
credentials=creds.to_dict(), credentials=creds.to_dict(),
) )
item.message = decoded.get('message', '') if isinstance(decoded, dict) else str(decoded) item.message = (
decoded.get("message", "") if isinstance(decoded, dict) else str(decoded)
)
else: else:
# Use stegasoo decode # Use stegasoo decode
item.message = self._do_decode(item, creds) item.message = self._do_decode(item, creds)
@@ -441,10 +450,7 @@ class BatchProcessor:
completed = 0 completed = 0
with ThreadPoolExecutor(max_workers=self.max_workers) as executor: with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = { futures = {executor.submit(process_func, item): item for item in result.items}
executor.submit(process_func, item): item
for item in result.items
}
for future in as_completed(futures): for future in as_completed(futures):
item = future.result() item = future.result()
@@ -469,7 +475,7 @@ class BatchProcessor:
message: str | None, message: str | None,
file_payload: Path | None, file_payload: Path | None,
creds: BatchCredentials, creds: BatchCredentials,
compress: bool compress: bool,
) -> None: ) -> None:
""" """
Perform actual encoding using stegasoo.encode. Perform actual encoding using stegasoo.encode.
@@ -555,16 +561,13 @@ class BatchProcessor:
return self._mock_decode(item, creds) return self._mock_decode(item, creds)
def _mock_encode( def _mock_encode(
self, self, item: BatchItem, message: str, creds: BatchCredentials, compress: bool
item: BatchItem,
message: str,
creds: BatchCredentials,
compress: bool
) -> None: ) -> None:
"""Mock encode for testing - replace with actual stego.encode()""" """Mock encode for testing - replace with actual stego.encode()"""
# This is a placeholder - in real usage, you'd call your actual encode function # This is a placeholder - in real usage, you'd call your actual encode function
# For now, just copy the file to simulate encoding # For now, just copy the file to simulate encoding
import shutil import shutil
if item.output_path: if item.output_path:
shutil.copy(item.input_path, item.output_path) shutil.copy(item.input_path, item.output_path)
@@ -605,7 +608,8 @@ def batch_capacity_check(
capacity_bits = pixels * 3 capacity_bits = pixels * 3
capacity_bytes = (capacity_bits // 8) - 100 # Header overhead capacity_bytes = (capacity_bits // 8) - 100 # Header overhead
results.append({ results.append(
{
"path": str(img_path), "path": str(img_path),
"dimensions": f"{width}x{height}", "dimensions": f"{width}x{height}",
"pixels": pixels, "pixels": pixels,
@@ -615,13 +619,16 @@ def batch_capacity_check(
"capacity_kb": max(0, capacity_bytes // 1024), "capacity_kb": max(0, capacity_bytes // 1024),
"valid": pixels <= MAX_IMAGE_PIXELS and img.format in LOSSLESS_FORMATS, "valid": pixels <= MAX_IMAGE_PIXELS and img.format in LOSSLESS_FORMATS,
"warnings": _get_image_warnings(img, img_path), "warnings": _get_image_warnings(img, img_path),
}) }
)
except Exception as e: except Exception as e:
results.append({ results.append(
{
"path": str(img_path), "path": str(img_path),
"error": str(e), "error": str(e),
"valid": False, "valid": False,
}) }
)
return results return results
@@ -638,7 +645,7 @@ def _get_image_warnings(img, path: Path) -> list[str]:
if img.size[0] * img.size[1] > MAX_IMAGE_PIXELS: if img.size[0] * img.size[1] > MAX_IMAGE_PIXELS:
warnings.append(f"Image exceeds {MAX_IMAGE_PIXELS:,} pixel limit") warnings.append(f"Image exceeds {MAX_IMAGE_PIXELS:,} pixel limit")
if img.mode not in ('RGB', 'RGBA'): if img.mode not in ("RGB", "RGBA"):
warnings.append(f"Non-RGB mode ({img.mode}) - will be converted") warnings.append(f"Non-RGB mode ({img.mode}) - will be converted")
return warnings return warnings
@@ -646,6 +653,7 @@ def _get_image_warnings(img, path: Path) -> list[str]:
# CLI-friendly functions # CLI-friendly functions
def print_batch_result(result: BatchResult, verbose: bool = False) -> None: def print_batch_result(result: BatchResult, verbose: bool = False) -> None:
"""Print batch result summary to console.""" """Print batch result summary to console."""
print(f"\n{'='*60}") print(f"\n{'='*60}")

View File

@@ -34,17 +34,17 @@ from .debug import debug
# Channel key format: 8 groups of 4 alphanumeric chars (32 chars total) # Channel key format: 8 groups of 4 alphanumeric chars (32 chars total)
# Example: ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456 # Example: ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456
CHANNEL_KEY_PATTERN = re.compile(r'^[A-Z0-9]{4}(-[A-Z0-9]{4}){7}$') CHANNEL_KEY_PATTERN = re.compile(r"^[A-Z0-9]{4}(-[A-Z0-9]{4}){7}$")
CHANNEL_KEY_LENGTH = 32 # Characters (excluding dashes) CHANNEL_KEY_LENGTH = 32 # Characters (excluding dashes)
CHANNEL_KEY_FORMATTED_LENGTH = 39 # With dashes CHANNEL_KEY_FORMATTED_LENGTH = 39 # With dashes
# Environment variable name # Environment variable name
CHANNEL_KEY_ENV_VAR = 'STEGASOO_CHANNEL_KEY' CHANNEL_KEY_ENV_VAR = "STEGASOO_CHANNEL_KEY"
# Config locations (in priority order) # Config locations (in priority order)
CONFIG_LOCATIONS = [ CONFIG_LOCATIONS = [
Path('./config/channel.key'), # Project config Path("./config/channel.key"), # Project config
Path.home() / '.stegasoo' / 'channel.key', # User config Path.home() / ".stegasoo" / "channel.key", # User config
] ]
@@ -61,8 +61,8 @@ def generate_channel_key() -> str:
39 39
""" """
# Generate 32 random alphanumeric characters # Generate 32 random alphanumeric characters
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
raw_key = ''.join(secrets.choice(alphabet) for _ in range(CHANNEL_KEY_LENGTH)) raw_key = "".join(secrets.choice(alphabet) for _ in range(CHANNEL_KEY_LENGTH))
formatted = format_channel_key(raw_key) formatted = format_channel_key(raw_key)
debug.print(f"Generated channel key: {get_channel_fingerprint(formatted)}") debug.print(f"Generated channel key: {get_channel_fingerprint(formatted)}")
@@ -87,19 +87,17 @@ def format_channel_key(raw_key: str) -> str:
"ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456" "ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456"
""" """
# Remove any existing dashes, spaces, and convert to uppercase # Remove any existing dashes, spaces, and convert to uppercase
clean = raw_key.replace('-', '').replace(' ', '').upper() clean = raw_key.replace("-", "").replace(" ", "").upper()
if len(clean) != CHANNEL_KEY_LENGTH: if len(clean) != CHANNEL_KEY_LENGTH:
raise ValueError( raise ValueError(f"Channel key must be {CHANNEL_KEY_LENGTH} characters (got {len(clean)})")
f"Channel key must be {CHANNEL_KEY_LENGTH} characters (got {len(clean)})"
)
# Validate characters # Validate characters
if not all(c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' for c in clean): if not all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" for c in clean):
raise ValueError("Channel key must contain only letters A-Z and digits 0-9") raise ValueError("Channel key must contain only letters A-Z and digits 0-9")
# Format with dashes every 4 characters # Format with dashes every 4 characters
return '-'.join(clean[i:i+4] for i in range(0, CHANNEL_KEY_LENGTH, 4)) return "-".join(clean[i : i + 4] for i in range(0, CHANNEL_KEY_LENGTH, 4))
def validate_channel_key(key: str) -> bool: def validate_channel_key(key: str) -> bool:
@@ -148,7 +146,7 @@ def get_channel_key() -> str | None:
... print("Public mode") ... print("Public mode")
""" """
# 1. Check environment variable # 1. Check environment variable
env_key = os.environ.get(CHANNEL_KEY_ENV_VAR, '').strip() env_key = os.environ.get(CHANNEL_KEY_ENV_VAR, "").strip()
if env_key: if env_key:
if validate_channel_key(env_key): if validate_channel_key(env_key):
debug.print(f"Channel key from environment: {get_channel_fingerprint(env_key)}") debug.print(f"Channel key from environment: {get_channel_fingerprint(env_key)}")
@@ -173,7 +171,7 @@ def get_channel_key() -> str | None:
return None return None
def set_channel_key(key: str, location: str = 'project') -> Path: def set_channel_key(key: str, location: str = "project") -> Path:
""" """
Save a channel key to config file. Save a channel key to config file.
@@ -194,16 +192,16 @@ def set_channel_key(key: str, location: str = 'project') -> Path:
""" """
formatted = format_channel_key(key) formatted = format_channel_key(key)
if location == 'user': if location == "user":
config_path = Path.home() / '.stegasoo' / 'channel.key' config_path = Path.home() / ".stegasoo" / "channel.key"
else: else:
config_path = Path('./config/channel.key') config_path = Path("./config/channel.key")
# Create directory if needed # Create directory if needed
config_path.parent.mkdir(parents=True, exist_ok=True) config_path.parent.mkdir(parents=True, exist_ok=True)
# Write key with newline # Write key with newline
config_path.write_text(formatted + '\n') config_path.write_text(formatted + "\n")
# Set restrictive permissions (owner read/write only) # Set restrictive permissions (owner read/write only)
try: try:
@@ -215,7 +213,7 @@ def set_channel_key(key: str, location: str = 'project') -> Path:
return config_path return config_path
def clear_channel_key(location: str = 'all') -> list[Path]: def clear_channel_key(location: str = "all") -> list[Path]:
""" """
Remove channel key configuration. Remove channel key configuration.
@@ -232,10 +230,10 @@ def clear_channel_key(location: str = 'all') -> list[Path]:
deleted = [] deleted = []
paths_to_check = [] paths_to_check = []
if location in ('project', 'all'): if location in ("project", "all"):
paths_to_check.append(Path('./config/channel.key')) paths_to_check.append(Path("./config/channel.key"))
if location in ('user', 'all'): if location in ("user", "all"):
paths_to_check.append(Path.home() / '.stegasoo' / 'channel.key') paths_to_check.append(Path.home() / ".stegasoo" / "channel.key")
for path in paths_to_check: for path in paths_to_check:
if path.exists(): if path.exists():
@@ -275,7 +273,7 @@ def get_channel_key_hash(key: str | None = None) -> bytes | None:
# Hash the formatted key to get consistent 32 bytes # Hash the formatted key to get consistent 32 bytes
formatted = format_channel_key(key) formatted = format_channel_key(key)
return hashlib.sha256(formatted.encode('utf-8')).digest() return hashlib.sha256(formatted.encode("utf-8")).digest()
def get_channel_fingerprint(key: str | None = None) -> str | None: def get_channel_fingerprint(key: str | None = None) -> str | None:
@@ -300,11 +298,11 @@ def get_channel_fingerprint(key: str | None = None) -> str | None:
return None return None
formatted = format_channel_key(key) formatted = format_channel_key(key)
parts = formatted.split('-') parts = formatted.split("-")
# Show first and last group, mask the rest # Show first and last group, mask the rest
masked = [parts[0]] + ['••••'] * 6 + [parts[-1]] masked = [parts[0]] + ["••••"] * 6 + [parts[-1]]
return '-'.join(masked) return "-".join(masked)
def get_channel_status() -> dict: def get_channel_status() -> dict:
@@ -328,10 +326,10 @@ def get_channel_status() -> dict:
if key: if key:
# Find which source provided the key # Find which source provided the key
source = 'unknown' source = "unknown"
env_key = os.environ.get(CHANNEL_KEY_ENV_VAR, '').strip() env_key = os.environ.get(CHANNEL_KEY_ENV_VAR, "").strip()
if env_key and validate_channel_key(env_key): if env_key and validate_channel_key(env_key):
source = 'environment' source = "environment"
else: else:
for config_path in CONFIG_LOCATIONS: for config_path in CONFIG_LOCATIONS:
if config_path.exists(): if config_path.exists():
@@ -344,19 +342,19 @@ def get_channel_status() -> dict:
continue continue
return { return {
'mode': 'private', "mode": "private",
'configured': True, "configured": True,
'fingerprint': get_channel_fingerprint(key), "fingerprint": get_channel_fingerprint(key),
'source': source, "source": source,
'key': key, "key": key,
} }
else: else:
return { return {
'mode': 'public', "mode": "public",
'configured': False, "configured": False,
'fingerprint': None, "fingerprint": None,
'source': None, "source": None,
'key': None, "key": None,
} }
@@ -378,14 +376,14 @@ def has_channel_key() -> bool:
# CLI SUPPORT # CLI SUPPORT
# ============================================================================= # =============================================================================
if __name__ == '__main__': if __name__ == "__main__":
import sys import sys
def print_status(): def print_status():
"""Print current channel status.""" """Print current channel status."""
status = get_channel_status() status = get_channel_status()
print(f"Mode: {status['mode'].upper()}") print(f"Mode: {status['mode'].upper()}")
if status['configured']: if status["configured"]:
print(f"Fingerprint: {status['fingerprint']}") print(f"Fingerprint: {status['fingerprint']}")
print(f"Source: {status['source']}") print(f"Source: {status['source']}")
else: else:
@@ -406,17 +404,17 @@ if __name__ == '__main__':
cmd = sys.argv[1].lower() cmd = sys.argv[1].lower()
if cmd == 'generate': if cmd == "generate":
key = generate_channel_key() key = generate_channel_key()
print("Generated channel key:") print("Generated channel key:")
print(f" {key}") print(f" {key}")
print() print()
save = input("Save to config? [y/N]: ").strip().lower() save = input("Save to config? [y/N]: ").strip().lower()
if save == 'y': if save == "y":
path = set_channel_key(key) path = set_channel_key(key)
print(f"Saved to: {path}") print(f"Saved to: {path}")
elif cmd == 'set': elif cmd == "set":
if len(sys.argv) < 3: if len(sys.argv) < 3:
print("Usage: python -m stegasoo.channel set <KEY>") print("Usage: python -m stegasoo.channel set <KEY>")
sys.exit(1) sys.exit(1)
@@ -431,22 +429,22 @@ if __name__ == '__main__':
print(f"Error: {e}") print(f"Error: {e}")
sys.exit(1) sys.exit(1)
elif cmd == 'show': elif cmd == "show":
status = get_channel_status() status = get_channel_status()
if status['configured']: if status["configured"]:
print(f"Channel key: {status['key']}") print(f"Channel key: {status['key']}")
print(f"Source: {status['source']}") print(f"Source: {status['source']}")
else: else:
print("No channel key configured") print("No channel key configured")
elif cmd == 'clear': elif cmd == "clear":
deleted = clear_channel_key('all') deleted = clear_channel_key("all")
if deleted: if deleted:
print(f"Removed channel key from: {', '.join(str(p) for p in deleted)}") print(f"Removed channel key from: {', '.join(str(p) for p in deleted)}")
else: else:
print("No channel key files found") print("No channel key files found")
elif cmd == 'status': elif cmd == "status":
print_status() print_status()
else: else:

View File

@@ -33,12 +33,12 @@ from .constants import (
) )
# Click context settings # Click context settings
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
@click.group(context_settings=CONTEXT_SETTINGS) @click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(__version__, '-v', '--version') @click.version_option(__version__, "-v", "--version")
@click.option('--json', 'json_output', is_flag=True, help='Output results as JSON') @click.option("--json", "json_output", is_flag=True, help="Output results as JSON")
@click.pass_context @click.pass_context
def cli(ctx, json_output): def cli(ctx, json_output):
""" """
@@ -47,31 +47,47 @@ def cli(ctx, json_output):
Hide messages in images using PIN + passphrase security. Hide messages in images using PIN + passphrase security.
""" """
ctx.ensure_object(dict) ctx.ensure_object(dict)
ctx.obj['json'] = json_output ctx.obj["json"] = json_output
# ============================================================================= # =============================================================================
# ENCODE COMMANDS # ENCODE COMMANDS
# ============================================================================= # =============================================================================
@cli.command() @cli.command()
@click.argument('image', type=click.Path(exists=True)) @click.argument("image", type=click.Path(exists=True))
@click.option('-m', '--message', help='Message to encode') @click.option("-m", "--message", help="Message to encode")
@click.option('-f', '--file', 'file_payload', type=click.Path(exists=True), @click.option(
help='File to embed instead of message') "-f",
@click.option('-o', '--output', type=click.Path(), help='Output image path') "--file",
@click.option('--passphrase', prompt=True, hide_input=True, "file_payload",
confirmation_prompt=True, help='Passphrase (recommend 4+ words)') type=click.Path(exists=True),
@click.option('--pin', prompt=True, hide_input=True, help="File to embed instead of message",
confirmation_prompt=True, help='PIN code') )
@click.option('--compress/--no-compress', default=True, @click.option("-o", "--output", type=click.Path(), help="Output image path")
help='Enable/disable compression (default: enabled)') @click.option(
@click.option('--algorithm', type=click.Choice(['zlib', 'lz4', 'none']), "--passphrase",
default='zlib', help='Compression algorithm') prompt=True,
@click.option('--dry-run', is_flag=True, help='Show capacity usage without encoding') hide_input=True,
confirmation_prompt=True,
help="Passphrase (recommend 4+ words)",
)
@click.option("--pin", prompt=True, hide_input=True, confirmation_prompt=True, help="PIN code")
@click.option(
"--compress/--no-compress", default=True, help="Enable/disable compression (default: enabled)"
)
@click.option(
"--algorithm",
type=click.Choice(["zlib", "lz4", "none"]),
default="zlib",
help="Compression algorithm",
)
@click.option("--dry-run", is_flag=True, help="Show capacity usage without encoding")
@click.pass_context @click.pass_context
def encode(ctx, image, message, file_payload, output, passphrase, pin, def encode(
compress, algorithm, dry_run): ctx, image, message, file_payload, output, passphrase, pin, compress, algorithm, dry_run
):
""" """
Encode a message or file into an image. Encode a message or file into an image.
@@ -88,13 +104,13 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
# Parse compression algorithm # Parse compression algorithm
algo_map = { algo_map = {
'zlib': CompressionAlgorithm.ZLIB, "zlib": CompressionAlgorithm.ZLIB,
'lz4': CompressionAlgorithm.LZ4, "lz4": CompressionAlgorithm.LZ4,
'none': CompressionAlgorithm.NONE, "none": CompressionAlgorithm.NONE,
} }
compression_algo = algo_map[algorithm] if compress else CompressionAlgorithm.NONE compression_algo = algo_map[algorithm] if compress else CompressionAlgorithm.NONE
if algorithm == 'lz4' and not HAS_LZ4: if algorithm == "lz4" and not HAS_LZ4:
click.echo("Warning: LZ4 not available, falling back to zlib", err=True) click.echo("Warning: LZ4 not available, falling back to zlib", err=True)
compression_algo = CompressionAlgorithm.ZLIB compression_algo = CompressionAlgorithm.ZLIB
@@ -103,7 +119,7 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
payload_size = Path(file_payload).stat().st_size payload_size = Path(file_payload).stat().st_size
payload_type = "file" payload_type = "file"
else: else:
payload_size = len(message.encode('utf-8')) payload_size = len(message.encode("utf-8"))
payload_type = "text" payload_type = "text"
# Get image capacity # Get image capacity
@@ -123,7 +139,7 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
"fits": payload_size < capacity_bytes, "fits": payload_size < capacity_bytes,
} }
if ctx.obj.get('json'): if ctx.obj.get("json"):
click.echo(json.dumps(result, indent=2)) click.echo(json.dumps(result, indent=2))
else: else:
click.echo(f"Image: {image} ({width}x{height})") click.echo(f"Image: {image} ({width}x{height})")
@@ -138,25 +154,29 @@ def encode(ctx, image, message, file_payload, output, passphrase, pin,
# For now, show what would be done # For now, show what would be done
output = output or f"{Path(image).stem}_encoded.png" output = output or f"{Path(image).stem}_encoded.png"
if ctx.obj.get('json'): if ctx.obj.get("json"):
click.echo(json.dumps({ click.echo(
json.dumps(
{
"status": "success", "status": "success",
"input": image, "input": image,
"output": output, "output": output,
"payload_type": payload_type, "payload_type": payload_type,
"compression": algorithm_name(compression_algo), "compression": algorithm_name(compression_algo),
}, indent=2)) },
indent=2,
)
)
else: else:
click.echo(f"✓ Encoded {payload_type} to {output}") click.echo(f"✓ Encoded {payload_type} to {output}")
click.echo(f" Compression: {algorithm_name(compression_algo)}") click.echo(f" Compression: {algorithm_name(compression_algo)}")
@cli.command() @cli.command()
@click.argument('image', type=click.Path(exists=True)) @click.argument("image", type=click.Path(exists=True))
@click.option('--passphrase', prompt=True, hide_input=True, help='Passphrase') @click.option("--passphrase", prompt=True, hide_input=True, help="Passphrase")
@click.option('--pin', prompt=True, hide_input=True, help='PIN code') @click.option("--pin", prompt=True, hide_input=True, help="PIN code")
@click.option('-o', '--output', type=click.Path(), @click.option("-o", "--output", type=click.Path(), help="Output path for file payloads")
help='Output path for file payloads')
@click.pass_context @click.pass_context
def decode(ctx, image, passphrase, pin, output): def decode(ctx, image, passphrase, pin, output):
""" """
@@ -176,46 +196,68 @@ def decode(ctx, image, passphrase, pin, output):
"message": "[Decoded message would appear here]", "message": "[Decoded message would appear here]",
} }
if ctx.obj.get('json'): if ctx.obj.get("json"):
click.echo(json.dumps(result, indent=2)) click.echo(json.dumps(result, indent=2))
else: else:
click.echo(f"Decoded from {image}:") click.echo(f"Decoded from {image}:")
click.echo(result['message']) click.echo(result["message"])
# ============================================================================= # =============================================================================
# BATCH COMMANDS # BATCH COMMANDS
# ============================================================================= # =============================================================================
@cli.group() @cli.group()
def batch(): def batch():
"""Batch operations on multiple images.""" """Batch operations on multiple images."""
pass pass
@batch.command('encode') @batch.command("encode")
@click.argument('images', nargs=-1, required=True, type=click.Path(exists=True)) @click.argument("images", nargs=-1, required=True, type=click.Path(exists=True))
@click.option('-m', '--message', help='Message to encode in all images') @click.option("-m", "--message", help="Message to encode in all images")
@click.option('-f', '--file', 'file_payload', type=click.Path(exists=True), @click.option(
help='File to embed in all images') "-f", "--file", "file_payload", type=click.Path(exists=True), help="File to embed in all images"
@click.option('-o', '--output-dir', type=click.Path(), )
help='Output directory (default: same as input)') @click.option(
@click.option('--suffix', default='_encoded', help='Output filename suffix') "-o", "--output-dir", type=click.Path(), help="Output directory (default: same as input)"
@click.option('--passphrase', prompt=True, hide_input=True, )
confirmation_prompt=True, help='Passphrase (recommend 4+ words)') @click.option("--suffix", default="_encoded", help="Output filename suffix")
@click.option('--pin', prompt=True, hide_input=True, @click.option(
confirmation_prompt=True, help='PIN code') "--passphrase",
@click.option('--compress/--no-compress', default=True, prompt=True,
help='Enable/disable compression') hide_input=True,
@click.option('--algorithm', type=click.Choice(['zlib', 'lz4', 'none']), confirmation_prompt=True,
default='zlib', help='Compression algorithm') help="Passphrase (recommend 4+ words)",
@click.option('-r', '--recursive', is_flag=True, )
help='Search directories recursively') @click.option("--pin", prompt=True, hide_input=True, confirmation_prompt=True, help="PIN code")
@click.option('-j', '--jobs', default=4, help='Parallel workers (default: 4)') @click.option("--compress/--no-compress", default=True, help="Enable/disable compression")
@click.option('-v', '--verbose', is_flag=True, help='Show detailed output') @click.option(
"--algorithm",
type=click.Choice(["zlib", "lz4", "none"]),
default="zlib",
help="Compression algorithm",
)
@click.option("-r", "--recursive", is_flag=True, help="Search directories recursively")
@click.option("-j", "--jobs", default=4, help="Parallel workers (default: 4)")
@click.option("-v", "--verbose", is_flag=True, help="Show detailed output")
@click.pass_context @click.pass_context
def batch_encode(ctx, images, message, file_payload, output_dir, suffix, def batch_encode(
passphrase, pin, compress, algorithm, recursive, jobs, verbose): ctx,
images,
message,
file_payload,
output_dir,
suffix,
passphrase,
pin,
compress,
algorithm,
recursive,
jobs,
verbose,
):
""" """
Encode message into multiple images. Encode message into multiple images.
@@ -232,7 +274,7 @@ def batch_encode(ctx, images, message, file_payload, output_dir, suffix,
# Progress callback # Progress callback
def progress(current, total, item): def progress(current, total, item):
if not ctx.obj.get('json'): if not ctx.obj.get("json"):
status = "" if item.status.value == "success" else "" status = "" if item.status.value == "success" else ""
click.echo(f"[{current}/{total}] {status} {item.input_path.name}") click.echo(f"[{current}/{total}] {status} {item.input_path.name}")
@@ -248,25 +290,23 @@ def batch_encode(ctx, images, message, file_payload, output_dir, suffix,
credentials=credentials, credentials=credentials,
compress=compress, compress=compress,
recursive=recursive, recursive=recursive,
progress_callback=progress if not ctx.obj.get('json') else None, progress_callback=progress if not ctx.obj.get("json") else None,
) )
if ctx.obj.get('json'): if ctx.obj.get("json"):
click.echo(result.to_json()) click.echo(result.to_json())
else: else:
print_batch_result(result, verbose) print_batch_result(result, verbose)
@batch.command('decode') @batch.command("decode")
@click.argument('images', nargs=-1, required=True, type=click.Path(exists=True)) @click.argument("images", nargs=-1, required=True, type=click.Path(exists=True))
@click.option('-o', '--output-dir', type=click.Path(), @click.option("-o", "--output-dir", type=click.Path(), help="Output directory for file payloads")
help='Output directory for file payloads') @click.option("--passphrase", prompt=True, hide_input=True, help="Passphrase")
@click.option('--passphrase', prompt=True, hide_input=True, help='Passphrase') @click.option("--pin", prompt=True, hide_input=True, help="PIN code")
@click.option('--pin', prompt=True, hide_input=True, help='PIN code') @click.option("-r", "--recursive", is_flag=True, help="Search directories recursively")
@click.option('-r', '--recursive', is_flag=True, @click.option("-j", "--jobs", default=4, help="Parallel workers (default: 4)")
help='Search directories recursively') @click.option("-v", "--verbose", is_flag=True, help="Show detailed output")
@click.option('-j', '--jobs', default=4, help='Parallel workers (default: 4)')
@click.option('-v', '--verbose', is_flag=True, help='Show detailed output')
@click.pass_context @click.pass_context
def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verbose): def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verbose):
""" """
@@ -282,7 +322,7 @@ def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verb
# Progress callback # Progress callback
def progress(current, total, item): def progress(current, total, item):
if not ctx.obj.get('json'): if not ctx.obj.get("json"):
status = "" if item.status.value == "success" else "" status = "" if item.status.value == "success" else ""
click.echo(f"[{current}/{total}] {status} {item.input_path.name}") click.echo(f"[{current}/{total}] {status} {item.input_path.name}")
@@ -294,19 +334,18 @@ def batch_decode(ctx, images, output_dir, passphrase, pin, recursive, jobs, verb
output_dir=Path(output_dir) if output_dir else None, output_dir=Path(output_dir) if output_dir else None,
credentials=credentials, credentials=credentials,
recursive=recursive, recursive=recursive,
progress_callback=progress if not ctx.obj.get('json') else None, progress_callback=progress if not ctx.obj.get("json") else None,
) )
if ctx.obj.get('json'): if ctx.obj.get("json"):
click.echo(result.to_json()) click.echo(result.to_json())
else: else:
print_batch_result(result, verbose) print_batch_result(result, verbose)
@batch.command('check') @batch.command("check")
@click.argument('images', nargs=-1, required=True, type=click.Path(exists=True)) @click.argument("images", nargs=-1, required=True, type=click.Path(exists=True))
@click.option('-r', '--recursive', is_flag=True, @click.option("-r", "--recursive", is_flag=True, help="Search directories recursively")
help='Search directories recursively')
@click.pass_context @click.pass_context
def batch_check(ctx, images, recursive): def batch_check(ctx, images, recursive):
""" """
@@ -320,22 +359,22 @@ def batch_check(ctx, images, recursive):
""" """
results = batch_capacity_check(list(images), recursive) results = batch_capacity_check(list(images), recursive)
if ctx.obj.get('json'): if ctx.obj.get("json"):
click.echo(json.dumps(results, indent=2)) click.echo(json.dumps(results, indent=2))
else: else:
click.echo(f"{'Image':<40} {'Size':<12} {'Capacity':<12} {'Status'}") click.echo(f"{'Image':<40} {'Size':<12} {'Capacity':<12} {'Status'}")
click.echo("" * 80) click.echo("" * 80)
for item in results: for item in results:
if 'error' in item: if "error" in item:
click.echo(f"{Path(item['path']).name:<40} {'ERROR':<12} {'':<12} {item['error']}") click.echo(f"{Path(item['path']).name:<40} {'ERROR':<12} {'':<12} {item['error']}")
else: else:
name = Path(item['path']).name name = Path(item["path"]).name
if len(name) > 38: if len(name) > 38:
name = name[:35] + "..." name = name[:35] + "..."
status = "" if item['valid'] else "" status = "" if item["valid"] else ""
warnings = ", ".join(item.get('warnings', [])) warnings = ", ".join(item.get("warnings", []))
click.echo( click.echo(
f"{name:<40} " f"{name:<40} "
@@ -349,11 +388,16 @@ def batch_check(ctx, images, recursive):
# UTILITY COMMANDS # UTILITY COMMANDS
# ============================================================================= # =============================================================================
@cli.command() @cli.command()
@click.option('--words', default=DEFAULT_PASSPHRASE_WORDS, @click.option(
help=f'Number of words in passphrase (default: {DEFAULT_PASSPHRASE_WORDS})') "--words",
@click.option('--pin-length', default=DEFAULT_PIN_LENGTH, default=DEFAULT_PASSPHRASE_WORDS,
help=f'PIN length (default: {DEFAULT_PIN_LENGTH})') help=f"Number of words in passphrase (default: {DEFAULT_PASSPHRASE_WORDS})",
)
@click.option(
"--pin-length", default=DEFAULT_PIN_LENGTH, help=f"PIN length (default: {DEFAULT_PIN_LENGTH})"
)
@click.pass_context @click.pass_context
def generate(ctx, words, pin_length): def generate(ctx, words, pin_length):
""" """
@@ -368,24 +412,37 @@ def generate(ctx, words, pin_length):
import secrets import secrets
# Generate PIN # Generate PIN
pin = ''.join(str(secrets.randbelow(10)) for _ in range(pin_length)) pin = "".join(str(secrets.randbelow(10)) for _ in range(pin_length))
# Ensure PIN doesn't start with 0 # Ensure PIN doesn't start with 0
if pin[0] == '0': if pin[0] == "0":
pin = str(secrets.randbelow(9) + 1) + pin[1:] pin = str(secrets.randbelow(9) + 1) + pin[1:]
# Generate passphrase (would use BIP-39 wordlist) # Generate passphrase (would use BIP-39 wordlist)
# Placeholder - actual implementation uses constants.get_wordlist() # Placeholder - actual implementation uses constants.get_wordlist()
try: try:
from .constants import get_wordlist from .constants import get_wordlist
wordlist = get_wordlist() wordlist = get_wordlist()
phrase_words = [secrets.choice(wordlist) for _ in range(words)] phrase_words = [secrets.choice(wordlist) for _ in range(words)]
except (ImportError, FileNotFoundError): except (ImportError, FileNotFoundError):
# Fallback for testing # Fallback for testing
sample_words = ['alpha', 'bravo', 'charlie', 'delta', 'echo', 'foxtrot', sample_words = [
'golf', 'hotel', 'india', 'juliet', 'kilo', 'lima'] "alpha",
"bravo",
"charlie",
"delta",
"echo",
"foxtrot",
"golf",
"hotel",
"india",
"juliet",
"kilo",
"lima",
]
phrase_words = [secrets.choice(sample_words) for _ in range(words)] phrase_words = [secrets.choice(sample_words) for _ in range(words)]
passphrase = ' '.join(phrase_words) passphrase = " ".join(phrase_words)
result = { result = {
"passphrase": passphrase, "passphrase": passphrase,
@@ -394,7 +451,7 @@ def generate(ctx, words, pin_length):
"pin_length": pin_length, "pin_length": pin_length,
} }
if ctx.obj.get('json'): if ctx.obj.get("json"):
click.echo(json.dumps(result, indent=2)) click.echo(json.dumps(result, indent=2))
else: else:
click.echo(f"Passphrase: {passphrase}") click.echo(f"Passphrase: {passphrase}")
@@ -418,7 +475,7 @@ def info(ctx):
}, },
} }
if ctx.obj.get('json'): if ctx.obj.get("json"):
click.echo(json.dumps(info_data, indent=2)) click.echo(json.dumps(info_data, indent=2))
else: else:
click.echo(f"Stegasoo v{__version__}") click.echo(f"Stegasoo v{__version__}")
@@ -437,5 +494,5 @@ def main():
cli(obj={}) cli(obj={})
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@@ -12,6 +12,7 @@ from enum import IntEnum
# Optional LZ4 support (faster, slightly worse ratio) # Optional LZ4 support (faster, slightly worse ratio)
try: try:
import lz4.frame import lz4.frame
HAS_LZ4 = True HAS_LZ4 = True
except ImportError: except ImportError:
HAS_LZ4 = False HAS_LZ4 = False
@@ -19,13 +20,14 @@ except ImportError:
class CompressionAlgorithm(IntEnum): class CompressionAlgorithm(IntEnum):
"""Supported compression algorithms.""" """Supported compression algorithms."""
NONE = 0 NONE = 0
ZLIB = 1 ZLIB = 1
LZ4 = 2 LZ4 = 2
# Magic bytes for compressed payloads # Magic bytes for compressed payloads
COMPRESSION_MAGIC = b'\x00CMP' COMPRESSION_MAGIC = b"\x00CMP"
# Minimum size to bother compressing (small data often expands) # Minimum size to bother compressing (small data often expands)
MIN_COMPRESS_SIZE = 64 MIN_COMPRESS_SIZE = 64
@@ -36,6 +38,7 @@ ZLIB_LEVEL = 6
class CompressionError(Exception): class CompressionError(Exception):
"""Raised when compression/decompression fails.""" """Raised when compression/decompression fails."""
pass pass
@@ -77,7 +80,7 @@ def compress(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm
return _wrap_uncompressed(data) return _wrap_uncompressed(data)
# Build header: MAGIC + algorithm + original_size + compressed_data # Build header: MAGIC + algorithm + original_size + compressed_data
header = COMPRESSION_MAGIC + struct.pack('<BI', algorithm, len(data)) header = COMPRESSION_MAGIC + struct.pack("<BI", algorithm, len(data))
return header + compressed return header + compressed
@@ -101,7 +104,7 @@ def decompress(data: bytes) -> bytes:
# Parse header # Parse header
algorithm = CompressionAlgorithm(data[4]) algorithm = CompressionAlgorithm(data[4])
original_size = struct.unpack('<I', data[5:9])[0] original_size = struct.unpack("<I", data[5:9])[0]
compressed_data = data[9:] compressed_data = data[9:]
if algorithm == CompressionAlgorithm.NONE: if algorithm == CompressionAlgorithm.NONE:
@@ -125,16 +128,14 @@ def decompress(data: bytes) -> bytes:
# Verify size # Verify size
if len(result) != original_size: if len(result) != original_size:
raise CompressionError( raise CompressionError(f"Size mismatch: expected {original_size}, got {len(result)}")
f"Size mismatch: expected {original_size}, got {len(result)}"
)
return result return result
def _wrap_uncompressed(data: bytes) -> bytes: def _wrap_uncompressed(data: bytes) -> bytes:
"""Wrap uncompressed data with header for consistency.""" """Wrap uncompressed data with header for consistency."""
header = COMPRESSION_MAGIC + struct.pack('<BI', CompressionAlgorithm.NONE, len(data)) header = COMPRESSION_MAGIC + struct.pack("<BI", CompressionAlgorithm.NONE, len(data))
return header + data return header + data
@@ -150,7 +151,9 @@ def get_compression_ratio(original: bytes, compressed: bytes) -> float:
return len(compressed) / len(original) return len(compressed) / len(original)
def estimate_compressed_size(data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm.ZLIB) -> int: def estimate_compressed_size(
data: bytes, algorithm: CompressionAlgorithm = CompressionAlgorithm.ZLIB
) -> int:
""" """
Estimate compressed size without full compression. Estimate compressed size without full compression.
Uses sampling for large data. Uses sampling for large data.

View File

@@ -26,7 +26,7 @@ __version__ = "4.0.1"
# FILE FORMAT # FILE FORMAT
# ============================================================================ # ============================================================================
MAGIC_HEADER = b'\x89ST3' MAGIC_HEADER = b"\x89ST3"
# FORMAT VERSION HISTORY: # FORMAT VERSION HISTORY:
# Version 1-3: Date-dependent encryption (v3.0.x - v3.1.x) # Version 1-3: Date-dependent encryption (v3.0.x - v3.1.x)
@@ -119,11 +119,11 @@ QR_CROP_MIN_PADDING_PX = 10 # Minimum padding in pixels
# FILE TYPES # FILE TYPES
# ============================================================================ # ============================================================================
ALLOWED_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'bmp', 'gif'} ALLOWED_IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "bmp", "gif"}
ALLOWED_KEY_EXTENSIONS = {'pem', 'key'} ALLOWED_KEY_EXTENSIONS = {"pem", "key"}
# Lossless image formats (safe for steganography) # Lossless image formats (safe for steganography)
LOSSLESS_FORMATS = {'PNG', 'BMP', 'TIFF'} LOSSLESS_FORMATS = {"PNG", "BMP", "TIFF"}
# Supported image formats for steganography # Supported image formats for steganography
SUPPORTED_IMAGE_FORMATS = LOSSLESS_FORMATS SUPPORTED_IMAGE_FORMATS = LOSSLESS_FORMATS
@@ -132,7 +132,7 @@ SUPPORTED_IMAGE_FORMATS = LOSSLESS_FORMATS
# DAYS (kept for organizational/UI purposes, not crypto) # DAYS (kept for organizational/UI purposes, not crypto)
# ============================================================================ # ============================================================================
DAY_NAMES = ('Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday') DAY_NAMES = ("Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday")
# ============================================================================ # ============================================================================
# COMPRESSION # COMPRESSION
@@ -145,7 +145,7 @@ MIN_COMPRESS_SIZE = 64
ZLIB_COMPRESSION_LEVEL = 6 ZLIB_COMPRESSION_LEVEL = 6
# Compression header magic bytes # Compression header magic bytes
COMPRESSION_MAGIC = b'\x00CMP' COMPRESSION_MAGIC = b"\x00CMP"
# ============================================================================ # ============================================================================
# BATCH PROCESSING # BATCH PROCESSING
@@ -164,6 +164,7 @@ BATCH_OUTPUT_SUFFIX = "_encoded"
# DATA FILES # DATA FILES
# ============================================================================ # ============================================================================
def get_data_dir() -> Path: def get_data_dir() -> Path:
"""Get the data directory path.""" """Get the data directory path."""
# Check multiple locations # Check multiple locations
@@ -172,12 +173,12 @@ def get_data_dir() -> Path:
# .parent.parent = src/ # .parent.parent = src/
# .parent.parent.parent = project root (where data/ lives) # .parent.parent.parent = project root (where data/ lives)
candidates = [ candidates = [
Path(__file__).parent.parent.parent / 'data', # Development: src/stegasoo -> project root Path(__file__).parent.parent.parent / "data", # Development: src/stegasoo -> project root
Path(__file__).parent / 'data', # Installed package Path(__file__).parent / "data", # Installed package
Path('/app/data'), # Docker Path("/app/data"), # Docker
Path.cwd() / 'data', # Current directory Path.cwd() / "data", # Current directory
Path.cwd().parent / 'data', # One level up from cwd Path.cwd().parent / "data", # One level up from cwd
Path.cwd().parent.parent / 'data', # Two levels up from cwd Path.cwd().parent.parent / "data", # Two levels up from cwd
] ]
for path in candidates: for path in candidates:
@@ -190,7 +191,7 @@ def get_data_dir() -> Path:
def get_bip39_words() -> list[str]: def get_bip39_words() -> list[str]:
"""Load BIP-39 wordlist.""" """Load BIP-39 wordlist."""
wordlist_path = get_data_dir() / 'bip39-words.txt' wordlist_path = get_data_dir() / "bip39-words.txt"
if not wordlist_path.exists(): if not wordlist_path.exists():
raise FileNotFoundError( raise FileNotFoundError(
@@ -219,12 +220,12 @@ def get_wordlist() -> list[str]:
# ============================================================================= # =============================================================================
# Embedding modes # Embedding modes
EMBED_MODE_LSB = 'lsb' # Spatial LSB embedding (default, original mode) EMBED_MODE_LSB = "lsb" # Spatial LSB embedding (default, original mode)
EMBED_MODE_DCT = 'dct' # DCT domain embedding (new in v3.0) EMBED_MODE_DCT = "dct" # DCT domain embedding (new in v3.0)
EMBED_MODE_AUTO = 'auto' # Auto-detect on decode EMBED_MODE_AUTO = "auto" # Auto-detect on decode
# DCT-specific constants # DCT-specific constants
DCT_MAGIC_HEADER = b'\x89DCT' # Magic header for DCT mode DCT_MAGIC_HEADER = b"\x89DCT" # Magic header for DCT mode
DCT_FORMAT_VERSION = 1 DCT_FORMAT_VERSION = 1
DCT_STEP_SIZE = 8 # QIM quantization step DCT_STEP_SIZE = 8 # QIM quantization step
@@ -247,13 +248,13 @@ def detect_stego_mode(encrypted_data: bytes) -> str:
'lsb' or 'dct' or 'unknown' 'lsb' or 'dct' or 'unknown'
""" """
if len(encrypted_data) < 4: if len(encrypted_data) < 4:
return 'unknown' return "unknown"
header = encrypted_data[:4] header = encrypted_data[:4]
if header == b'\x89ST3': if header == b"\x89ST3":
return EMBED_MODE_LSB return EMBED_MODE_LSB
elif header == b'\x89DCT': elif header == b"\x89DCT":
return EMBED_MODE_DCT return EMBED_MODE_DCT
else: else:
return 'unknown' return "unknown"

View File

@@ -44,6 +44,7 @@ from .models import DecodeResult, FilePayload
# Check for Argon2 availability # Check for Argon2 availability
try: try:
from argon2.low_level import Type, hash_secret_raw from argon2.low_level import Type, hash_secret_raw
HAS_ARGON2 = True HAS_ARGON2 = True
except ImportError: except ImportError:
HAS_ARGON2 = False HAS_ARGON2 = False
@@ -79,15 +80,17 @@ def _resolve_channel_key(channel_key: str | bool | None) -> bytes | None:
# Auto-detect from environment/config # Auto-detect from environment/config
if channel_key is None or channel_key == CHANNEL_KEY_AUTO: if channel_key is None or channel_key == CHANNEL_KEY_AUTO:
from .channel import get_channel_key_hash from .channel import get_channel_key_hash
return get_channel_key_hash() return get_channel_key_hash()
# Explicit key provided - validate and hash it # Explicit key provided - validate and hash it
if isinstance(channel_key, str): if isinstance(channel_key, str):
from .channel import format_channel_key, validate_channel_key from .channel import format_channel_key, validate_channel_key
if not validate_channel_key(channel_key): if not validate_channel_key(channel_key):
raise ValueError(f"Invalid channel key format: {channel_key}") raise ValueError(f"Invalid channel key format: {channel_key}")
formatted = format_channel_key(channel_key) formatted = format_channel_key(channel_key)
return hashlib.sha256(formatted.encode('utf-8')).digest() return hashlib.sha256(formatted.encode("utf-8")).digest()
raise ValueError(f"Invalid channel_key type: {type(channel_key)}") raise ValueError(f"Invalid channel_key type: {type(channel_key)}")
@@ -96,6 +99,7 @@ def _resolve_channel_key(channel_key: str | bool | None) -> bytes | None:
# CORE CRYPTO FUNCTIONS # CORE CRYPTO FUNCTIONS
# ============================================================================= # =============================================================================
def hash_photo(image_data: bytes) -> bytes: def hash_photo(image_data: bytes) -> bytes:
""" """
Compute deterministic hash of photo pixel content. Compute deterministic hash of photo pixel content.
@@ -109,7 +113,7 @@ def hash_photo(image_data: bytes) -> bytes:
Returns: Returns:
32-byte SHA-256 hash 32-byte SHA-256 hash
""" """
img: Image.Image = Image.open(io.BytesIO(image_data)).convert('RGB') img: Image.Image = Image.open(io.BytesIO(image_data)).convert("RGB")
pixels = img.tobytes() pixels = img.tobytes()
# Double-hash with prefix for additional mixing # Double-hash with prefix for additional mixing
@@ -163,12 +167,7 @@ def derive_hybrid_key(
channel_hash = _resolve_channel_key(channel_key) channel_hash = _resolve_channel_key(channel_key)
# Build key material # Build key material
key_material = ( key_material = photo_hash + passphrase.lower().encode() + pin.encode() + salt
photo_hash +
passphrase.lower().encode() +
pin.encode() +
salt
)
# Add RSA key hash if provided # Add RSA key hash if provided
if rsa_key_data: if rsa_key_data:
@@ -186,7 +185,7 @@ def derive_hybrid_key(
memory_cost=ARGON2_MEMORY_COST, memory_cost=ARGON2_MEMORY_COST,
parallelism=ARGON2_PARALLELISM, parallelism=ARGON2_PARALLELISM,
hash_len=32, hash_len=32,
type=Type.ID type=Type.ID,
) )
else: else:
kdf = PBKDF2HMAC( kdf = PBKDF2HMAC(
@@ -194,7 +193,7 @@ def derive_hybrid_key(
length=32, length=32,
salt=salt, salt=salt,
iterations=PBKDF2_ITERATIONS, iterations=PBKDF2_ITERATIONS,
backend=default_backend() backend=default_backend(),
) )
key = kdf.derive(key_material) key = kdf.derive(key_material)
@@ -232,11 +231,7 @@ def derive_pixel_key(
# Resolve channel key # Resolve channel key
channel_hash = _resolve_channel_key(channel_key) channel_hash = _resolve_channel_key(channel_key)
material = ( material = photo_hash + passphrase.lower().encode() + pin.encode()
photo_hash +
passphrase.lower().encode() +
pin.encode()
)
if rsa_key_data: if rsa_key_data:
material += hashlib.sha256(rsa_key_data).digest() material += hashlib.sha256(rsa_key_data).digest()
@@ -268,31 +263,31 @@ def _pack_payload(
""" """
if isinstance(content, str): if isinstance(content, str):
# Text message # Text message
data = content.encode('utf-8') data = content.encode("utf-8")
return bytes([PAYLOAD_TEXT]) + data, PAYLOAD_TEXT return bytes([PAYLOAD_TEXT]) + data, PAYLOAD_TEXT
elif isinstance(content, FilePayload): elif isinstance(content, FilePayload):
# File with metadata # File with metadata
filename = content.filename[:MAX_FILENAME_LENGTH].encode('utf-8') filename = content.filename[:MAX_FILENAME_LENGTH].encode("utf-8")
mime = (content.mime_type or '')[:100].encode('utf-8') mime = (content.mime_type or "")[:100].encode("utf-8")
packed = ( packed = (
bytes([PAYLOAD_FILE]) + bytes([PAYLOAD_FILE])
struct.pack('>H', len(filename)) + + struct.pack(">H", len(filename))
filename + + filename
struct.pack('>H', len(mime)) + + struct.pack(">H", len(mime))
mime + + mime
content.data + content.data
) )
return packed, PAYLOAD_FILE return packed, PAYLOAD_FILE
else: else:
# Raw bytes - treat as file with no name # Raw bytes - treat as file with no name
packed = ( packed = (
bytes([PAYLOAD_FILE]) + bytes([PAYLOAD_FILE])
struct.pack('>H', 0) + # No filename + struct.pack(">H", 0) # No filename
struct.pack('>H', 0) + # No mime + struct.pack(">H", 0) # No mime
content + content
) )
return packed, PAYLOAD_FILE return packed, PAYLOAD_FILE
@@ -314,42 +309,39 @@ def _unpack_payload(data: bytes) -> DecodeResult:
if payload_type == PAYLOAD_TEXT: if payload_type == PAYLOAD_TEXT:
# Text message # Text message
text = data[1:].decode('utf-8') text = data[1:].decode("utf-8")
return DecodeResult(payload_type='text', message=text) return DecodeResult(payload_type="text", message=text)
elif payload_type == PAYLOAD_FILE: elif payload_type == PAYLOAD_FILE:
# File with metadata # File with metadata
offset = 1 offset = 1
# Read filename # Read filename
filename_len = struct.unpack('>H', data[offset:offset+2])[0] filename_len = struct.unpack(">H", data[offset : offset + 2])[0]
offset += 2 offset += 2
filename = data[offset:offset+filename_len].decode('utf-8') if filename_len else None filename = data[offset : offset + filename_len].decode("utf-8") if filename_len else None
offset += filename_len offset += filename_len
# Read mime type # Read mime type
mime_len = struct.unpack('>H', data[offset:offset+2])[0] mime_len = struct.unpack(">H", data[offset : offset + 2])[0]
offset += 2 offset += 2
mime_type = data[offset:offset+mime_len].decode('utf-8') if mime_len else None mime_type = data[offset : offset + mime_len].decode("utf-8") if mime_len else None
offset += mime_len offset += mime_len
# Rest is file data # Rest is file data
file_data = data[offset:] file_data = data[offset:]
return DecodeResult( return DecodeResult(
payload_type='file', payload_type="file", file_data=file_data, filename=filename, mime_type=mime_type
file_data=file_data,
filename=filename,
mime_type=mime_type
) )
else: else:
# Unknown type - try to decode as text (backward compatibility) # Unknown type - try to decode as text (backward compatibility)
try: try:
text = data.decode('utf-8') text = data.decode("utf-8")
return DecodeResult(payload_type='text', message=text) return DecodeResult(payload_type="text", message=text)
except UnicodeDecodeError: except UnicodeDecodeError:
return DecodeResult(payload_type='file', file_data=data) return DecodeResult(payload_type="file", file_data=data)
# ============================================================================= # =============================================================================
@@ -415,7 +407,7 @@ def encrypt_message(
padding_len = secrets.randbelow(256) + 64 padding_len = secrets.randbelow(256) + 64
padded_len = ((len(packed_payload) + padding_len + 255) // 256) * 256 padded_len = ((len(packed_payload) + padding_len + 255) // 256) * 256
padding_needed = padded_len - len(packed_payload) padding_needed = padded_len - len(packed_payload)
padding = secrets.token_bytes(padding_needed - 4) + struct.pack('>I', len(packed_payload)) padding = secrets.token_bytes(padding_needed - 4) + struct.pack(">I", len(packed_payload))
padded_message = packed_payload + padding padded_message = packed_payload + padding
# Build header for AAD # Build header for AAD
@@ -428,13 +420,7 @@ def encrypt_message(
ciphertext = encryptor.update(padded_message) + encryptor.finalize() ciphertext = encryptor.update(padded_message) + encryptor.finalize()
# v4.0.0: Header with flags byte # v4.0.0: Header with flags byte
return ( return header + salt + iv + encryptor.tag + ciphertext
header +
salt +
iv +
encryptor.tag +
ciphertext
)
except Exception as e: except Exception as e:
raise EncryptionError(f"Encryption failed: {e}") from e raise EncryptionError(f"Encryption failed: {e}") from e
@@ -473,13 +459,13 @@ def parse_header(encrypted_data: bytes) -> dict | None:
ciphertext = encrypted_data[offset:] ciphertext = encrypted_data[offset:]
return { return {
'version': version, "version": version,
'flags': flags, "flags": flags,
'has_channel_key': bool(flags & FLAG_CHANNEL_KEY), "has_channel_key": bool(flags & FLAG_CHANNEL_KEY),
'salt': salt, "salt": salt,
'iv': iv, "iv": iv,
'tag': tag, "tag": tag,
'ciphertext': ciphertext "ciphertext": ciphertext,
} }
except Exception: except Exception:
return None return None
@@ -518,26 +504,24 @@ def decrypt_message(
# Check for channel key mismatch and provide helpful error # Check for channel key mismatch and provide helpful error
channel_hash = _resolve_channel_key(channel_key) channel_hash = _resolve_channel_key(channel_key)
has_configured_key = channel_hash is not None has_configured_key = channel_hash is not None
message_has_key = header['has_channel_key'] message_has_key = header["has_channel_key"]
try: try:
key = derive_hybrid_key( key = derive_hybrid_key(
photo_data, passphrase, header['salt'], pin, rsa_key_data, channel_key photo_data, passphrase, header["salt"], pin, rsa_key_data, channel_key
) )
# Reconstruct header for AAD verification # Reconstruct header for AAD verification
aad_header = MAGIC_HEADER + bytes([FORMAT_VERSION, header['flags']]) aad_header = MAGIC_HEADER + bytes([FORMAT_VERSION, header["flags"]])
cipher = Cipher( cipher = Cipher(
algorithms.AES(key), algorithms.AES(key), modes.GCM(header["iv"], header["tag"]), backend=default_backend()
modes.GCM(header['iv'], header['tag']),
backend=default_backend()
) )
decryptor = cipher.decryptor() decryptor = cipher.decryptor()
decryptor.authenticate_additional_data(aad_header) decryptor.authenticate_additional_data(aad_header)
padded_plaintext = decryptor.update(header['ciphertext']) + decryptor.finalize() padded_plaintext = decryptor.update(header["ciphertext"]) + decryptor.finalize()
original_length = struct.unpack('>I', padded_plaintext[-4:])[0] original_length = struct.unpack(">I", padded_plaintext[-4:])[0]
payload_data = padded_plaintext[:original_length] payload_data = padded_plaintext[:original_length]
result = _unpack_payload(payload_data) result = _unpack_payload(payload_data)
@@ -596,7 +580,7 @@ def decrypt_message_text(
if result.file_data: if result.file_data:
# Try to decode as text # Try to decode as text
try: try:
return result.file_data.decode('utf-8') return result.file_data.decode("utf-8")
except UnicodeDecodeError: except UnicodeDecodeError:
raise DecryptionError( raise DecryptionError(
f"Content is a binary file ({result.filename or 'unnamed'}), not text" f"Content is a binary file ({result.filename or 'unnamed'}), not text"
@@ -615,6 +599,7 @@ def has_argon2() -> bool:
# CHANNEL KEY UTILITIES (exposed for convenience) # CHANNEL KEY UTILITIES (exposed for convenience)
# ============================================================================= # =============================================================================
def get_active_channel_key() -> str | None: def get_active_channel_key() -> str | None:
""" """
Get the currently configured channel key (if any). Get the currently configured channel key (if any).
@@ -623,6 +608,7 @@ def get_active_channel_key() -> str | None:
Formatted channel key string, or None if not configured Formatted channel key string, or None if not configured
""" """
from .channel import get_channel_key from .channel import get_channel_key
return get_channel_key() return get_channel_key()
@@ -637,4 +623,5 @@ def get_channel_fingerprint(key: str | None = None) -> str | None:
Masked key like "ABCD-••••-••••-••••-••••-••••-••••-3456" or None Masked key like "ABCD-••••-••••-••••-••••-••••-••••-3456" or None
""" """
from .channel import get_channel_fingerprint as _get_fingerprint from .channel import get_channel_fingerprint as _get_fingerprint
return _get_fingerprint(key) return _get_fingerprint(key)

View File

@@ -28,10 +28,12 @@ from PIL import Image
# Prefer scipy.fft (newer, more stable) over scipy.fftpack # Prefer scipy.fft (newer, more stable) over scipy.fftpack
try: try:
from scipy.fft import dct, idct from scipy.fft import dct, idct
HAS_SCIPY = True HAS_SCIPY = True
except ImportError: except ImportError:
try: try:
from scipy.fftpack import dct, idct from scipy.fftpack import dct, idct
HAS_SCIPY = True HAS_SCIPY = True
except ImportError: except ImportError:
HAS_SCIPY = False HAS_SCIPY = False
@@ -41,6 +43,7 @@ except ImportError:
# Check for jpegio availability (for proper JPEG mode) # Check for jpegio availability (for proper JPEG mode)
try: try:
import jpegio as jio import jpegio as jio
HAS_JPEGIO = True HAS_JPEGIO = True
except ImportError: except ImportError:
HAS_JPEGIO = False HAS_JPEGIO = False
@@ -53,19 +56,49 @@ except ImportError:
BLOCK_SIZE = 8 BLOCK_SIZE = 8
EMBED_POSITIONS = [ EMBED_POSITIONS = [
(0, 1), (1, 0), (2, 0), (1, 1), (0, 2), (0, 3), (1, 2), (2, 1), (3, 0), (0, 1),
(4, 0), (3, 1), (2, 2), (1, 3), (0, 4), (0, 5), (1, 4), (2, 3), (3, 2), (1, 0),
(4, 1), (5, 0), (5, 1), (4, 2), (3, 3), (2, 4), (1, 5), (0, 6), (0, 7), (2, 0),
(1, 6), (2, 5), (3, 4), (4, 3), (5, 2), (6, 1), (7, 0), (1, 1),
(0, 2),
(0, 3),
(1, 2),
(2, 1),
(3, 0),
(4, 0),
(3, 1),
(2, 2),
(1, 3),
(0, 4),
(0, 5),
(1, 4),
(2, 3),
(3, 2),
(4, 1),
(5, 0),
(5, 1),
(4, 2),
(3, 3),
(2, 4),
(1, 5),
(0, 6),
(0, 7),
(1, 6),
(2, 5),
(3, 4),
(4, 3),
(5, 2),
(6, 1),
(7, 0),
] ]
DEFAULT_EMBED_POSITIONS = EMBED_POSITIONS[4:20] DEFAULT_EMBED_POSITIONS = EMBED_POSITIONS[4:20]
QUANT_STEP = 25 QUANT_STEP = 25
DCT_MAGIC = b'DCTS' DCT_MAGIC = b"DCTS"
HEADER_SIZE = 10 HEADER_SIZE = 10
OUTPUT_FORMAT_PNG = 'png' OUTPUT_FORMAT_PNG = "png"
OUTPUT_FORMAT_JPEG = 'jpeg' OUTPUT_FORMAT_JPEG = "jpeg"
JPEG_OUTPUT_QUALITY = 95 JPEG_OUTPUT_QUALITY = 95
JPEGIO_MAGIC = b'JPGS' JPEGIO_MAGIC = b"JPGS"
JPEGIO_MIN_COEF_MAGNITUDE = 2 JPEGIO_MIN_COEF_MAGNITUDE = 2
JPEGIO_EMBED_CHANNEL = 0 JPEGIO_EMBED_CHANNEL = 0
FLAG_COLOR_MODE = 0x01 FLAG_COLOR_MODE = 0x01
@@ -83,9 +116,10 @@ JPEGIO_MAX_QUANT_VALUE_THRESHOLD = 1 # If all quant values <= this, normalize
# DATA CLASSES # DATA CLASSES
# ============================================================================ # ============================================================================
class DCTOutputFormat(Enum): class DCTOutputFormat(Enum):
PNG = 'png' PNG = "png"
JPEG = 'jpeg' JPEG = "jpeg"
@dataclass @dataclass
@@ -99,7 +133,7 @@ class DCTEmbedStats:
image_height: int image_height: int
output_format: str output_format: str
jpeg_native: bool = False jpeg_native: bool = False
color_mode: str = 'grayscale' color_mode: str = "grayscale"
@dataclass @dataclass
@@ -119,11 +153,10 @@ class DCTCapacityInfo:
# AVAILABILITY CHECKS # AVAILABILITY CHECKS
# ============================================================================ # ============================================================================
def _check_scipy(): def _check_scipy():
if not HAS_SCIPY: if not HAS_SCIPY:
raise ImportError( raise ImportError("DCT steganography requires scipy. Install with: pip install scipy")
"DCT steganography requires scipy. Install with: pip install scipy"
)
def has_dct_support() -> bool: def has_dct_support() -> bool:
@@ -139,25 +172,26 @@ def has_jpegio_support() -> bool:
# These create fresh arrays to avoid scipy memory corruption issues # These create fresh arrays to avoid scipy memory corruption issues
# ============================================================================ # ============================================================================
def _safe_dct2(block: np.ndarray) -> np.ndarray: def _safe_dct2(block: np.ndarray) -> np.ndarray:
""" """
Apply 2D DCT with memory isolation. Apply 2D DCT with memory isolation.
Creates a completely fresh array to avoid heap corruption. Creates a completely fresh array to avoid heap corruption.
""" """
# Create a brand new array (not a view) # Create a brand new array (not a view)
safe_block = np.array(block, dtype=np.float64, copy=True, order='C') safe_block = np.array(block, dtype=np.float64, copy=True, order="C")
# First DCT on columns (transpose -> DCT rows -> transpose back) # First DCT on columns (transpose -> DCT rows -> transpose back)
temp = np.zeros_like(safe_block, dtype=np.float64, order='C') temp = np.zeros_like(safe_block, dtype=np.float64, order="C")
for i in range(BLOCK_SIZE): for i in range(BLOCK_SIZE):
col = np.array(safe_block[:, i], dtype=np.float64, copy=True) col = np.array(safe_block[:, i], dtype=np.float64, copy=True)
temp[:, i] = dct(col, norm='ortho') temp[:, i] = dct(col, norm="ortho")
# Second DCT on rows # Second DCT on rows
result = np.zeros_like(temp, dtype=np.float64, order='C') result = np.zeros_like(temp, dtype=np.float64, order="C")
for i in range(BLOCK_SIZE): for i in range(BLOCK_SIZE):
row = np.array(temp[i, :], dtype=np.float64, copy=True) row = np.array(temp[i, :], dtype=np.float64, copy=True)
result[i, :] = dct(row, norm='ortho') result[i, :] = dct(row, norm="ortho")
return result return result
@@ -168,19 +202,19 @@ def _safe_idct2(block: np.ndarray) -> np.ndarray:
Creates a completely fresh array to avoid heap corruption. Creates a completely fresh array to avoid heap corruption.
""" """
# Create a brand new array (not a view) # Create a brand new array (not a view)
safe_block = np.array(block, dtype=np.float64, copy=True, order='C') safe_block = np.array(block, dtype=np.float64, copy=True, order="C")
# First IDCT on rows # First IDCT on rows
temp = np.zeros_like(safe_block, dtype=np.float64, order='C') temp = np.zeros_like(safe_block, dtype=np.float64, order="C")
for i in range(BLOCK_SIZE): for i in range(BLOCK_SIZE):
row = np.array(safe_block[i, :], dtype=np.float64, copy=True) row = np.array(safe_block[i, :], dtype=np.float64, copy=True)
temp[i, :] = idct(row, norm='ortho') temp[i, :] = idct(row, norm="ortho")
# Second IDCT on columns # Second IDCT on columns
result = np.zeros_like(temp, dtype=np.float64, order='C') result = np.zeros_like(temp, dtype=np.float64, order="C")
for i in range(BLOCK_SIZE): for i in range(BLOCK_SIZE):
col = np.array(temp[:, i], dtype=np.float64, copy=True) col = np.array(temp[:, i], dtype=np.float64, copy=True)
result[:, i] = idct(col, norm='ortho') result[:, i] = idct(col, norm="ortho")
return result return result
@@ -189,20 +223,21 @@ def _safe_idct2(block: np.ndarray) -> np.ndarray:
# IMAGE PROCESSING HELPERS # IMAGE PROCESSING HELPERS
# ============================================================================ # ============================================================================
def _to_grayscale(image_data: bytes) -> np.ndarray: def _to_grayscale(image_data: bytes) -> np.ndarray:
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
gray = img.convert('L') gray = img.convert("L")
return np.array(gray, dtype=np.float64, copy=True, order='C') return np.array(gray, dtype=np.float64, copy=True, order="C")
def _extract_y_channel(image_data: bytes) -> np.ndarray: def _extract_y_channel(image_data: bytes) -> np.ndarray:
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
if img.mode != 'RGB': if img.mode != "RGB":
img = img.convert('RGB') img = img.convert("RGB")
rgb = np.array(img, dtype=np.float64, copy=True, order='C') rgb = np.array(img, dtype=np.float64, copy=True, order="C")
Y = 0.299 * rgb[:, :, 0] + 0.587 * rgb[:, :, 1] + 0.114 * rgb[:, :, 2] Y = 0.299 * rgb[:, :, 0] + 0.587 * rgb[:, :, 1] + 0.114 * rgb[:, :, 2]
return np.array(Y, dtype=np.float64, copy=True, order='C') return np.array(Y, dtype=np.float64, copy=True, order="C")
def _pad_to_blocks(image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]: def _pad_to_blocks(image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
@@ -211,9 +246,9 @@ def _pad_to_blocks(image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
new_w = ((w + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE new_w = ((w + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
if new_h == h and new_w == w: if new_h == h and new_w == w:
return np.array(image, dtype=np.float64, copy=True, order='C'), (h, w) return np.array(image, dtype=np.float64, copy=True, order="C"), (h, w)
padded = np.zeros((new_h, new_w), dtype=np.float64, order='C') padded = np.zeros((new_h, new_w), dtype=np.float64, order="C")
padded[:h, :w] = image padded[:h, :w] = image
# Simple edge replication for padding # Simple edge replication for padding
@@ -231,7 +266,7 @@ def _pad_to_blocks(image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
def _unpad_image(image: np.ndarray, original_size: tuple[int, int]) -> np.ndarray: def _unpad_image(image: np.ndarray, original_size: tuple[int, int]) -> np.ndarray:
h, w = original_size h, w = original_size
return np.array(image[:h, :w], dtype=np.float64, copy=True, order='C') return np.array(image[:h, :w], dtype=np.float64, copy=True, order="C")
def _embed_bit_in_coeff(coef: float, bit: int, quant_step: int = QUANT_STEP) -> float: def _embed_bit_in_coeff(coef: float, bit: int, quant_step: int = QUANT_STEP) -> float:
@@ -251,7 +286,7 @@ def _extract_bit_from_coeff(coef: float, quant_step: int = QUANT_STEP) -> int:
def _generate_block_order(num_blocks: int, seed: bytes) -> list: def _generate_block_order(num_blocks: int, seed: bytes) -> list:
hash_bytes = hashlib.sha256(seed).digest() hash_bytes = hashlib.sha256(seed).digest()
rng = np.random.RandomState(int.from_bytes(hash_bytes[:4], 'big')) rng = np.random.RandomState(int.from_bytes(hash_bytes[:4], "big"))
order = list(range(num_blocks)) order = list(range(num_blocks))
rng.shuffle(order) rng.shuffle(order)
return order return order
@@ -259,25 +294,23 @@ def _generate_block_order(num_blocks: int, seed: bytes) -> list:
def _save_stego_image(image: np.ndarray, output_format: str = OUTPUT_FORMAT_PNG) -> bytes: def _save_stego_image(image: np.ndarray, output_format: str = OUTPUT_FORMAT_PNG) -> bytes:
clipped = np.clip(image, 0, 255).astype(np.uint8) clipped = np.clip(image, 0, 255).astype(np.uint8)
img = Image.fromarray(clipped, mode='L') img = Image.fromarray(clipped, mode="L")
buffer = io.BytesIO() buffer = io.BytesIO()
if output_format == OUTPUT_FORMAT_JPEG: if output_format == OUTPUT_FORMAT_JPEG:
img.save(buffer, format='JPEG', quality=JPEG_OUTPUT_QUALITY, img.save(buffer, format="JPEG", quality=JPEG_OUTPUT_QUALITY, subsampling=0, optimize=True)
subsampling=0, optimize=True)
else: else:
img.save(buffer, format='PNG', optimize=True) img.save(buffer, format="PNG", optimize=True)
return buffer.getvalue() return buffer.getvalue()
def _save_color_image(rgb_array: np.ndarray, output_format: str = OUTPUT_FORMAT_PNG) -> bytes: def _save_color_image(rgb_array: np.ndarray, output_format: str = OUTPUT_FORMAT_PNG) -> bytes:
clipped = np.clip(rgb_array, 0, 255).astype(np.uint8) clipped = np.clip(rgb_array, 0, 255).astype(np.uint8)
img = Image.fromarray(clipped, mode='RGB') img = Image.fromarray(clipped, mode="RGB")
buffer = io.BytesIO() buffer = io.BytesIO()
if output_format == OUTPUT_FORMAT_JPEG: if output_format == OUTPUT_FORMAT_JPEG:
img.save(buffer, format='JPEG', quality=JPEG_OUTPUT_QUALITY, img.save(buffer, format="JPEG", quality=JPEG_OUTPUT_QUALITY, subsampling=0, optimize=True)
subsampling=0, optimize=True)
else: else:
img.save(buffer, format='PNG', optimize=True) img.save(buffer, format="PNG", optimize=True)
return buffer.getvalue() return buffer.getvalue()
@@ -286,9 +319,13 @@ def _rgb_to_ycbcr(rgb: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
G = rgb[:, :, 1].astype(np.float64) G = rgb[:, :, 1].astype(np.float64)
B = rgb[:, :, 2].astype(np.float64) B = rgb[:, :, 2].astype(np.float64)
Y = np.array(0.299 * R + 0.587 * G + 0.114 * B, dtype=np.float64, copy=True, order='C') Y = np.array(0.299 * R + 0.587 * G + 0.114 * B, dtype=np.float64, copy=True, order="C")
Cb = np.array(128 - 0.168736 * R - 0.331264 * G + 0.5 * B, dtype=np.float64, copy=True, order='C') Cb = np.array(
Cr = np.array(128 + 0.5 * R - 0.418688 * G - 0.081312 * B, dtype=np.float64, copy=True, order='C') 128 - 0.168736 * R - 0.331264 * G + 0.5 * B, dtype=np.float64, copy=True, order="C"
)
Cr = np.array(
128 + 0.5 * R - 0.418688 * G - 0.081312 * B, dtype=np.float64, copy=True, order="C"
)
return Y, Cb, Cr return Y, Cb, Cr
@@ -298,7 +335,7 @@ def _ycbcr_to_rgb(Y: np.ndarray, Cb: np.ndarray, Cr: np.ndarray) -> np.ndarray:
G = Y - 0.344136 * (Cb - 128) - 0.714136 * (Cr - 128) G = Y - 0.344136 * (Cb - 128) - 0.714136 * (Cr - 128)
B = Y + 1.772 * (Cb - 128) B = Y + 1.772 * (Cb - 128)
rgb = np.zeros((Y.shape[0], Y.shape[1], 3), dtype=np.float64, order='C') rgb = np.zeros((Y.shape[0], Y.shape[1], 3), dtype=np.float64, order="C")
rgb[:, :, 0] = R rgb[:, :, 0] = R
rgb[:, :, 1] = G rgb[:, :, 1] = G
rgb[:, :, 2] = B rgb[:, :, 2] = B
@@ -306,19 +343,21 @@ def _ycbcr_to_rgb(Y: np.ndarray, Cb: np.ndarray, Cr: np.ndarray) -> np.ndarray:
def _create_header(data_length: int, flags: int = 0) -> bytes: def _create_header(data_length: int, flags: int = 0) -> bytes:
return struct.pack('>4sBBI', DCT_MAGIC, 1, flags, data_length) return struct.pack(">4sBBI", DCT_MAGIC, 1, flags, data_length)
def _parse_header(header_bits: list) -> tuple[int, int, int]: def _parse_header(header_bits: list) -> tuple[int, int, int]:
if len(header_bits) < HEADER_SIZE * 8: if len(header_bits) < HEADER_SIZE * 8:
raise ValueError("Insufficient header data") raise ValueError("Insufficient header data")
header_bytes = bytes([ header_bytes = bytes(
[
sum(header_bits[i * 8 : (i + 1) * 8][j] << (7 - j) for j in range(8)) sum(header_bits[i * 8 : (i + 1) * 8][j] << (7 - j) for j in range(8))
for i in range(HEADER_SIZE) for i in range(HEADER_SIZE)
]) ]
)
magic, version, flags, length = struct.unpack('>4sBBI', header_bytes) magic, version, flags, length = struct.unpack(">4sBBI", header_bytes)
if magic != DCT_MAGIC: if magic != DCT_MAGIC:
raise ValueError("Invalid DCT stego magic bytes") raise ValueError("Invalid DCT stego magic bytes")
@@ -330,9 +369,11 @@ def _parse_header(header_bits: list) -> tuple[int, int, int]:
# JPEGIO HELPERS # JPEGIO HELPERS
# ============================================================================ # ============================================================================
def _jpegio_bytes_to_file(data: bytes, suffix: str = '.jpg') -> str:
def _jpegio_bytes_to_file(data: bytes, suffix: str = ".jpg") -> str:
import os import os
import tempfile import tempfile
fd, path = tempfile.mkstemp(suffix=suffix) fd, path = tempfile.mkstemp(suffix=suffix)
try: try:
os.write(fd, data) os.write(fd, data)
@@ -355,20 +396,20 @@ def _jpegio_get_usable_positions(coef_array: np.ndarray) -> list:
def _jpegio_generate_order(num_positions: int, seed: bytes) -> list: def _jpegio_generate_order(num_positions: int, seed: bytes) -> list:
hash_bytes = hashlib.sha256(seed + b"jpeg_coef_order").digest() hash_bytes = hashlib.sha256(seed + b"jpeg_coef_order").digest()
rng = np.random.RandomState(int.from_bytes(hash_bytes[:4], 'big')) rng = np.random.RandomState(int.from_bytes(hash_bytes[:4], "big"))
order = list(range(num_positions)) order = list(range(num_positions))
rng.shuffle(order) rng.shuffle(order)
return order return order
def _jpegio_create_header(data_length: int, flags: int = 0) -> bytes: def _jpegio_create_header(data_length: int, flags: int = 0) -> bytes:
return struct.pack('>4sBBI', JPEGIO_MAGIC, 1, flags, data_length) return struct.pack(">4sBBI", JPEGIO_MAGIC, 1, flags, data_length)
def _jpegio_parse_header(header_bytes: bytes) -> tuple[int, int, int]: def _jpegio_parse_header(header_bytes: bytes) -> tuple[int, int, int]:
if len(header_bytes) < HEADER_SIZE: if len(header_bytes) < HEADER_SIZE:
raise ValueError("Insufficient header data") raise ValueError("Insufficient header data")
magic, version, flags, length = struct.unpack('>4sBBI', header_bytes[:HEADER_SIZE]) magic, version, flags, length = struct.unpack(">4sBBI", header_bytes[:HEADER_SIZE])
if magic != JPEGIO_MAGIC: if magic != JPEGIO_MAGIC:
raise ValueError(f"Invalid JPEG stego magic: {magic}") raise ValueError(f"Invalid JPEG stego magic: {magic}")
return version, flags, length return version, flags, length
@@ -378,6 +419,7 @@ def _jpegio_parse_header(header_bytes: bytes) -> tuple[int, int, int]:
# PUBLIC API # PUBLIC API
# ============================================================================ # ============================================================================
def calculate_dct_capacity(image_data: bytes) -> DCTCapacityInfo: def calculate_dct_capacity(image_data: bytes) -> DCTCapacityInfo:
"""Calculate DCT embedding capacity of an image.""" """Calculate DCT embedding capacity of an image."""
_check_scipy() _check_scipy()
@@ -405,7 +447,7 @@ def calculate_dct_capacity(image_data: bytes) -> DCTCapacityInfo:
bits_per_block=bits_per_block, bits_per_block=bits_per_block,
total_capacity_bits=total_bits, total_capacity_bits=total_bits,
total_capacity_bytes=total_bytes, total_capacity_bytes=total_bytes,
usable_capacity_bytes=usable_bytes usable_capacity_bytes=usable_bytes,
) )
@@ -427,24 +469,24 @@ def estimate_capacity_comparison(image_data: bytes) -> dict:
dct_bytes = (blocks * 16) // 8 - HEADER_SIZE dct_bytes = (blocks * 16) // 8 - HEADER_SIZE
return { return {
'width': width, "width": width,
'height': height, "height": height,
'lsb': { "lsb": {
'capacity_bytes': lsb_bytes, "capacity_bytes": lsb_bytes,
'capacity_kb': lsb_bytes / 1024, "capacity_kb": lsb_bytes / 1024,
'output': 'PNG/BMP (color)', "output": "PNG/BMP (color)",
}, },
'dct': { "dct": {
'capacity_bytes': dct_bytes, "capacity_bytes": dct_bytes,
'capacity_kb': dct_bytes / 1024, "capacity_kb": dct_bytes / 1024,
'output': 'PNG or JPEG (grayscale)', "output": "PNG or JPEG (grayscale)",
'ratio_vs_lsb': (dct_bytes / lsb_bytes * 100) if lsb_bytes > 0 else 0, "ratio_vs_lsb": (dct_bytes / lsb_bytes * 100) if lsb_bytes > 0 else 0,
'available': HAS_SCIPY, "available": HAS_SCIPY,
},
"jpeg_native": {
"available": HAS_JPEGIO,
"note": "Uses jpegio for proper JPEG coefficient embedding",
}, },
'jpeg_native': {
'available': HAS_JPEGIO,
'note': 'Uses jpegio for proper JPEG coefficient embedding',
}
} }
@@ -453,14 +495,14 @@ def embed_in_dct(
carrier_image: bytes, carrier_image: bytes,
seed: bytes, seed: bytes,
output_format: str = OUTPUT_FORMAT_PNG, output_format: str = OUTPUT_FORMAT_PNG,
color_mode: str = 'color', color_mode: str = "color",
) -> tuple[bytes, DCTEmbedStats]: ) -> tuple[bytes, DCTEmbedStats]:
"""Embed data using DCT coefficient modification.""" """Embed data using DCT coefficient modification."""
if output_format not in (OUTPUT_FORMAT_PNG, OUTPUT_FORMAT_JPEG): if output_format not in (OUTPUT_FORMAT_PNG, OUTPUT_FORMAT_JPEG):
raise ValueError(f"Invalid output format: {output_format}") raise ValueError(f"Invalid output format: {output_format}")
if color_mode not in ('color', 'grayscale'): if color_mode not in ("color", "grayscale"):
color_mode = 'color' color_mode = "color"
if output_format == OUTPUT_FORMAT_JPEG and HAS_JPEGIO: if output_format == OUTPUT_FORMAT_JPEG and HAS_JPEGIO:
return _embed_jpegio(data, carrier_image, seed, color_mode) return _embed_jpegio(data, carrier_image, seed, color_mode)
@@ -474,7 +516,7 @@ def _embed_scipy_dct_safe(
carrier_image: bytes, carrier_image: bytes,
seed: bytes, seed: bytes,
output_format: str, output_format: str,
color_mode: str = 'color', color_mode: str = "color",
) -> tuple[bytes, DCTEmbedStats]: ) -> tuple[bytes, DCTEmbedStats]:
""" """
Embed using scipy DCT with safe memory handling. Embed using scipy DCT with safe memory handling.
@@ -494,7 +536,7 @@ def _embed_scipy_dct_safe(
img = Image.open(io.BytesIO(carrier_image)) img = Image.open(io.BytesIO(carrier_image))
width, height = img.size width, height = img.size
flags = FLAG_COLOR_MODE if color_mode == 'color' else 0 flags = FLAG_COLOR_MODE if color_mode == "color" else 0
# Prepare payload bits # Prepare payload bits
header = _create_header(len(data), flags) header = _create_header(len(data), flags)
@@ -509,12 +551,12 @@ def _embed_scipy_dct_safe(
block_order = _generate_block_order(num_blocks, seed) block_order = _generate_block_order(num_blocks, seed)
blocks_x = width // BLOCK_SIZE blocks_x = width // BLOCK_SIZE
if color_mode == 'color' and img.mode in ('RGB', 'RGBA'): if color_mode == "color" and img.mode in ("RGB", "RGBA"):
if img.mode == 'RGBA': if img.mode == "RGBA":
img = img.convert('RGB') img = img.convert("RGB")
# Process color image # Process color image
rgb = np.array(img, dtype=np.float64, copy=True, order='C') rgb = np.array(img, dtype=np.float64, copy=True, order="C")
img.close() img.close()
Y, Cb, Cr = _rgb_to_ycbcr(rgb) Y, Cb, Cr = _rgb_to_ycbcr(rgb)
@@ -592,7 +634,7 @@ def _embed_in_channel_safe(
h, w = channel.shape h, w = channel.shape
# Create result with explicit new memory # Create result with explicit new memory
result = np.array(channel, dtype=np.float64, copy=True, order='C') result = np.array(channel, dtype=np.float64, copy=True, order="C")
bit_idx = 0 bit_idx = 0
@@ -606,7 +648,9 @@ def _embed_in_channel_safe(
# Extract block - create brand new array # Extract block - create brand new array
block = np.array( block = np.array(
result[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE], result[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE],
dtype=np.float64, copy=True, order='C' dtype=np.float64,
copy=True,
order="C",
) )
# Apply safe DCT (row-by-row) # Apply safe DCT (row-by-row)
@@ -617,8 +661,7 @@ def _embed_in_channel_safe(
if bit_idx >= len(bits): if bit_idx >= len(bits):
break break
dct_block[pos[0], pos[1]] = _embed_bit_in_coeff( dct_block[pos[0], pos[1]] = _embed_bit_in_coeff(
float(dct_block[pos[0], pos[1]]), float(dct_block[pos[0], pos[1]]), bits[bit_idx]
bits[bit_idx]
) )
bit_idx += 1 bit_idx += 1
@@ -654,13 +697,13 @@ def _normalize_jpeg_for_jpegio(image_data: bytes) -> bytes:
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
# Only process JPEGs # Only process JPEGs
if img.format != 'JPEG': if img.format != "JPEG":
img.close() img.close()
return image_data return image_data
# Check quantization tables # Check quantization tables
needs_normalization = False needs_normalization = False
if hasattr(img, 'quantization') and img.quantization: if hasattr(img, "quantization") and img.quantization:
for table_id, table in img.quantization.items(): for table_id, table in img.quantization.items():
# If all values in any table are <= threshold, normalize # If all values in any table are <= threshold, normalize
if max(table) <= JPEGIO_MAX_QUANT_VALUE_THRESHOLD: if max(table) <= JPEGIO_MAX_QUANT_VALUE_THRESHOLD:
@@ -672,11 +715,11 @@ def _normalize_jpeg_for_jpegio(image_data: bytes) -> bytes:
return image_data return image_data
# Re-save at safe quality level # Re-save at safe quality level
if img.mode != 'RGB': if img.mode != "RGB":
img = img.convert('RGB') img = img.convert("RGB")
buffer = io.BytesIO() buffer = io.BytesIO()
img.save(buffer, format='JPEG', quality=JPEGIO_NORMALIZE_QUALITY, subsampling=0) img.save(buffer, format="JPEG", quality=JPEGIO_NORMALIZE_QUALITY, subsampling=0)
img.close() img.close()
return buffer.getvalue() return buffer.getvalue()
@@ -686,7 +729,7 @@ def _embed_jpegio(
data: bytes, data: bytes,
carrier_image: bytes, carrier_image: bytes,
seed: bytes, seed: bytes,
color_mode: str = 'color', color_mode: str = "color",
) -> tuple[bytes, DCTEmbedStats]: ) -> tuple[bytes, DCTEmbedStats]:
"""Embed using jpegio for proper JPEG coefficient modification.""" """Embed using jpegio for proper JPEG coefficient modification."""
import os import os
@@ -698,18 +741,18 @@ def _embed_jpegio(
img = Image.open(io.BytesIO(carrier_image)) img = Image.open(io.BytesIO(carrier_image))
width, height = img.size width, height = img.size
if img.format != 'JPEG': if img.format != "JPEG":
buffer = io.BytesIO() buffer = io.BytesIO()
if img.mode != 'RGB': if img.mode != "RGB":
img = img.convert('RGB') img = img.convert("RGB")
img.save(buffer, format='JPEG', quality=95, subsampling=0) img.save(buffer, format="JPEG", quality=95, subsampling=0)
carrier_image = buffer.getvalue() carrier_image = buffer.getvalue()
img.close() img.close()
input_path = _jpegio_bytes_to_file(carrier_image, suffix='.jpg') input_path = _jpegio_bytes_to_file(carrier_image, suffix=".jpg")
output_path = tempfile.mktemp(suffix='.jpg') output_path = tempfile.mktemp(suffix=".jpg")
flags = FLAG_COLOR_MODE if color_mode == 'color' else 0 flags = FLAG_COLOR_MODE if color_mode == "color" else 0
try: try:
jpeg = jio.read(input_path) jpeg = jio.read(input_path)
@@ -750,7 +793,7 @@ def _embed_jpegio(
jio.write(jpeg, output_path) jio.write(jpeg, output_path)
with open(output_path, 'rb') as f: with open(output_path, "rb") as f:
stego_bytes = f.read() stego_bytes = f.read()
stats = DCTEmbedStats( stats = DCTEmbedStats(
@@ -782,7 +825,7 @@ def extract_from_dct(stego_image: bytes, seed: bytes) -> bytes:
fmt = img.format fmt = img.format
img.close() img.close()
if fmt == 'JPEG' and HAS_JPEGIO: if fmt == "JPEG" and HAS_JPEGIO:
try: try:
return _extract_jpegio(stego_image, seed) return _extract_jpegio(stego_image, seed)
except ValueError: except ValueError:
@@ -798,7 +841,7 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
width, height = img.size width, height = img.size
mode = img.mode mode = img.mode
if mode in ('RGB', 'RGBA'): if mode in ("RGB", "RGBA"):
channel = _extract_y_channel(stego_image) channel = _extract_y_channel(stego_image)
else: else:
channel = _to_grayscale(stego_image) channel = _to_grayscale(stego_image)
@@ -822,7 +865,9 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
block = np.array( block = np.array(
padded[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE], padded[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE],
dtype=np.float64, copy=True, order='C' dtype=np.float64,
copy=True,
order="C",
) )
dct_block = _safe_dct2(block) dct_block = _safe_dct2(block)
@@ -847,10 +892,12 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
_, flags, data_length = _parse_header(all_bits) _, flags, data_length = _parse_header(all_bits)
data_bits = all_bits[HEADER_SIZE * 8 : (HEADER_SIZE + data_length) * 8] data_bits = all_bits[HEADER_SIZE * 8 : (HEADER_SIZE + data_length) * 8]
data = bytes([ data = bytes(
[
sum(data_bits[i * 8 : (i + 1) * 8][j] << (7 - j) for j in range(8)) sum(data_bits[i * 8 : (i + 1) * 8][j] << (7 - j) for j in range(8))
for i in range(data_length) for i in range(data_length)
]) ]
)
return data return data
@@ -863,7 +910,7 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
# (shouldn't happen with stego images, but be defensive) # (shouldn't happen with stego images, but be defensive)
stego_image = _normalize_jpeg_for_jpegio(stego_image) stego_image = _normalize_jpeg_for_jpegio(stego_image)
temp_path = _jpegio_bytes_to_file(stego_image, suffix='.jpg') temp_path = _jpegio_bytes_to_file(stego_image, suffix=".jpg")
try: try:
jpeg = jio.read(temp_path) jpeg = jio.read(temp_path)
@@ -878,10 +925,12 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
coef = coef_array[row, col] coef = coef_array[row, col]
header_bits.append(coef & 1) header_bits.append(coef & 1)
header_bytes = bytes([ header_bytes = bytes(
[
sum(header_bits[i * 8 : (i + 1) * 8][j] << (7 - j) for j in range(8)) sum(header_bits[i * 8 : (i + 1) * 8][j] << (7 - j) for j in range(8))
for i in range(HEADER_SIZE) for i in range(HEADER_SIZE)
]) ]
)
_, flags, data_length = _jpegio_parse_header(header_bytes) _, flags, data_length = _jpegio_parse_header(header_bytes)
@@ -897,10 +946,12 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
data_bits = all_bits[HEADER_SIZE * 8 :] data_bits = all_bits[HEADER_SIZE * 8 :]
data = bytes([ data = bytes(
[
sum(data_bits[i * 8 : (i + 1) * 8][j] << (7 - j) for j in range(8)) sum(data_bits[i * 8 : (i + 1) * 8][j] << (7 - j) for j in range(8))
for i in range(data_length) for i in range(data_length)
]) ]
)
return data return data
@@ -915,13 +966,14 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
# CONVENIENCE FUNCTIONS # CONVENIENCE FUNCTIONS
# ============================================================================ # ============================================================================
def get_output_extension(output_format: str) -> str: def get_output_extension(output_format: str) -> str:
if output_format == OUTPUT_FORMAT_JPEG: if output_format == OUTPUT_FORMAT_JPEG:
return '.jpg' return ".jpg"
return '.png' return ".png"
def get_output_mimetype(output_format: str) -> str: def get_output_mimetype(output_format: str) -> str:
if output_format == OUTPUT_FORMAT_JPEG: if output_format == OUTPUT_FORMAT_JPEG:
return 'image/jpeg' return "image/jpeg"
return 'image/png' return "image/png"

View File

@@ -68,6 +68,7 @@ def debug_exception(e: Exception, context: str = "") -> None:
def time_function(func: Callable) -> Callable: def time_function(func: Callable) -> Callable:
"""Decorator to time function execution for performance debugging.""" """Decorator to time function execution for performance debugging."""
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs) -> Any: def wrapper(*args, **kwargs) -> Any:
if not (DEBUG_ENABLED and LOG_PERFORMANCE): if not (DEBUG_ENABLED and LOG_PERFORMANCE):
@@ -96,16 +97,17 @@ def memory_usage() -> dict[str, float | str]:
import os import os
import psutil import psutil
process = psutil.Process(os.getpid()) process = psutil.Process(os.getpid())
mem_info = process.memory_info() mem_info = process.memory_info()
return { return {
'rss_mb': mem_info.rss / 1024 / 1024, "rss_mb": mem_info.rss / 1024 / 1024,
'vms_mb': mem_info.vms / 1024 / 1024, "vms_mb": mem_info.vms / 1024 / 1024,
'percent': process.memory_percent(), "percent": process.memory_percent(),
} }
except ImportError: except ImportError:
return {'error': 'psutil not installed'} return {"error": "psutil not installed"}
def hexdump(data: bytes, offset: int = 0, length: int = 64) -> str: def hexdump(data: bytes, offset: int = 0, length: int = 64) -> str:
@@ -118,15 +120,15 @@ def hexdump(data: bytes, offset: int = 0, length: int = 64) -> str:
for i in range(0, len(data_to_dump), 16): for i in range(0, len(data_to_dump), 16):
chunk = data_to_dump[i : i + 16] chunk = data_to_dump[i : i + 16]
hex_str = ' '.join(f'{b:02x}' for b in chunk) hex_str = " ".join(f"{b:02x}" for b in chunk)
hex_str = hex_str.ljust(47) hex_str = hex_str.ljust(47)
ascii_str = ''.join(chr(b) if 32 <= b < 127 else '.' for b in chunk) ascii_str = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk)
result.append(f"{offset + i:08x}: {hex_str} {ascii_str}") result.append(f"{offset + i:08x}: {hex_str} {ascii_str}")
if len(data) > length: if len(data) > length:
result.append(f"... ({len(data) - length} more bytes)") result.append(f"... ({len(data) - length} more bytes)")
return '\n'.join(result) return "\n".join(result)
class Debug: class Debug:

View File

@@ -75,9 +75,11 @@ def decode(
... channel_key="ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456" ... channel_key="ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456"
... ) ... )
""" """
debug.print(f"decode: passphrase length={len(passphrase.split())} words, " debug.print(
f"decode: passphrase length={len(passphrase.split())} words, "
f"mode={embed_mode}, " f"mode={embed_mode}, "
f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}") f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}"
)
# Validate inputs # Validate inputs
require_valid_image(stego_image, "Stego image") require_valid_image(stego_image, "Stego image")
@@ -91,9 +93,8 @@ def decode(
# Derive pixel/coefficient selection key (with channel key) # Derive pixel/coefficient selection key (with channel key)
from .crypto import derive_pixel_key from .crypto import derive_pixel_key
pixel_key = derive_pixel_key(
reference_photo, passphrase, pin, rsa_key_data, channel_key pixel_key = derive_pixel_key(reference_photo, passphrase, pin, rsa_key_data, channel_key)
)
# Extract encrypted data # Extract encrypted data
encrypted = extract_from_image( encrypted = extract_from_image(
@@ -109,9 +110,7 @@ def decode(
debug.print(f"Extracted {len(encrypted)} bytes from image") debug.print(f"Extracted {len(encrypted)} bytes from image")
# Decrypt (with channel key) # Decrypt (with channel key)
result = decrypt_message( result = decrypt_message(encrypted, reference_photo, passphrase, pin, rsa_key_data, channel_key)
encrypted, reference_photo, passphrase, pin, rsa_key_data, channel_key
)
debug.print(f"Decryption successful: {result.payload_type}") debug.print(f"Decryption successful: {result.payload_type}")
return result return result
@@ -222,7 +221,7 @@ def decode_text(
# Try to decode as text # Try to decode as text
if result.file_data: if result.file_data:
try: try:
return result.file_data.decode('utf-8') return result.file_data.decode("utf-8")
except UnicodeDecodeError: except UnicodeDecodeError:
raise DecryptionError( raise DecryptionError(
f"Payload is a binary file ({result.filename or 'unnamed'}), not text" f"Payload is a binary file ({result.filename or 'unnamed'}), not text"

View File

@@ -82,9 +82,11 @@ def encode(
... channel_key="ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456" ... channel_key="ABCD-1234-EFGH-5678-IJKL-9012-MNOP-3456"
... ) ... )
""" """
debug.print(f"encode: passphrase length={len(passphrase.split())} words, " debug.print(
f"encode: passphrase length={len(passphrase.split())} words, "
f"pin={'set' if pin else 'none'}, mode={embed_mode}, " f"pin={'set' if pin else 'none'}, mode={embed_mode}, "
f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}") f"channel_key={'explicit' if isinstance(channel_key, str) and channel_key else 'auto' if channel_key is None else 'none'}"
)
# Validate inputs # Validate inputs
require_valid_payload(message) require_valid_payload(message)
@@ -105,9 +107,7 @@ def encode(
debug.print(f"Encrypted payload: {len(encrypted)} bytes") debug.print(f"Encrypted payload: {len(encrypted)} bytes")
# Derive pixel/coefficient selection key (with channel key) # Derive pixel/coefficient selection key (with channel key)
pixel_key = derive_pixel_key( pixel_key = derive_pixel_key(reference_photo, passphrase, pin, rsa_key_data, channel_key)
reference_photo, passphrase, pin, rsa_key_data, channel_key
)
# Embed in image # Embed in image
stego_data, stats, extension = embed_in_image( stego_data, stats, extension = embed_in_image(
@@ -124,7 +124,7 @@ def encode(
filename = generate_filename(extension=extension) filename = generate_filename(extension=extension)
# Create result # Create result
if hasattr(stats, 'pixels_modified'): if hasattr(stats, "pixels_modified"):
# LSB mode stats # LSB mode stats
return EncodeResult( return EncodeResult(
stego_image=stego_data, stego_image=stego_data,

View File

@@ -7,6 +7,7 @@ Custom exception classes for clear error handling across all frontends.
class StegasooError(Exception): class StegasooError(Exception):
"""Base exception for all Stegasoo errors.""" """Base exception for all Stegasoo errors."""
pass pass
@@ -14,33 +15,40 @@ class StegasooError(Exception):
# VALIDATION ERRORS # VALIDATION ERRORS
# ============================================================================ # ============================================================================
class ValidationError(StegasooError): class ValidationError(StegasooError):
"""Base class for validation errors.""" """Base class for validation errors."""
pass pass
class PinValidationError(ValidationError): class PinValidationError(ValidationError):
"""PIN validation failed.""" """PIN validation failed."""
pass pass
class MessageValidationError(ValidationError): class MessageValidationError(ValidationError):
"""Message validation failed.""" """Message validation failed."""
pass pass
class ImageValidationError(ValidationError): class ImageValidationError(ValidationError):
"""Image validation failed.""" """Image validation failed."""
pass pass
class KeyValidationError(ValidationError): class KeyValidationError(ValidationError):
"""RSA key validation failed.""" """RSA key validation failed."""
pass pass
class SecurityFactorError(ValidationError): class SecurityFactorError(ValidationError):
"""Security factor requirements not met.""" """Security factor requirements not met."""
pass pass
@@ -48,33 +56,40 @@ class SecurityFactorError(ValidationError):
# CRYPTO ERRORS # CRYPTO ERRORS
# ============================================================================ # ============================================================================
class CryptoError(StegasooError): class CryptoError(StegasooError):
"""Base class for cryptographic errors.""" """Base class for cryptographic errors."""
pass pass
class EncryptionError(CryptoError): class EncryptionError(CryptoError):
"""Encryption failed.""" """Encryption failed."""
pass pass
class DecryptionError(CryptoError): class DecryptionError(CryptoError):
"""Decryption failed (wrong key, corrupted data, etc.).""" """Decryption failed (wrong key, corrupted data, etc.)."""
pass pass
class KeyDerivationError(CryptoError): class KeyDerivationError(CryptoError):
"""Key derivation failed.""" """Key derivation failed."""
pass pass
class KeyGenerationError(CryptoError): class KeyGenerationError(CryptoError):
"""Key generation failed.""" """Key generation failed."""
pass pass
class KeyPasswordError(CryptoError): class KeyPasswordError(CryptoError):
"""RSA key password is incorrect or missing.""" """RSA key password is incorrect or missing."""
pass pass
@@ -82,8 +97,10 @@ class KeyPasswordError(CryptoError):
# STEGANOGRAPHY ERRORS # STEGANOGRAPHY ERRORS
# ============================================================================ # ============================================================================
class SteganographyError(StegasooError): class SteganographyError(StegasooError):
"""Base class for steganography errors.""" """Base class for steganography errors."""
pass pass
@@ -100,16 +117,19 @@ class CapacityError(SteganographyError):
class ExtractionError(SteganographyError): class ExtractionError(SteganographyError):
"""Failed to extract hidden data from image.""" """Failed to extract hidden data from image."""
pass pass
class EmbeddingError(SteganographyError): class EmbeddingError(SteganographyError):
"""Failed to embed data in image.""" """Failed to embed data in image."""
pass pass
class InvalidHeaderError(SteganographyError): class InvalidHeaderError(SteganographyError):
"""Invalid or missing Stegasoo header in extracted data.""" """Invalid or missing Stegasoo header in extracted data."""
pass pass
@@ -117,13 +137,16 @@ class InvalidHeaderError(SteganographyError):
# FILE ERRORS # FILE ERRORS
# ============================================================================ # ============================================================================
class FileError(StegasooError): class FileError(StegasooError):
"""Base class for file-related errors.""" """Base class for file-related errors."""
pass pass
class FileNotFoundError(FileError): class FileNotFoundError(FileError):
"""Required file not found.""" """Required file not found."""
pass pass

View File

@@ -4,7 +4,6 @@ Stegasoo Generate Module (v3.2.0)
Public API for generating credentials (PINs, passphrases, RSA keys). Public API for generating credentials (PINs, passphrases, RSA keys).
""" """
from .constants import ( from .constants import (
DEFAULT_PASSPHRASE_WORDS, DEFAULT_PASSPHRASE_WORDS,
DEFAULT_PIN_LENGTH, DEFAULT_PIN_LENGTH,
@@ -26,12 +25,12 @@ from .models import Credentials
# Re-export from keygen for convenience # Re-export from keygen for convenience
__all__ = [ __all__ = [
'generate_pin', "generate_pin",
'generate_passphrase', "generate_passphrase",
'generate_rsa_key', "generate_rsa_key",
'generate_credentials', "generate_credentials",
'export_rsa_key_pem', "export_rsa_key_pem",
'load_rsa_key', "load_rsa_key",
] ]
@@ -78,10 +77,7 @@ def generate_passphrase(words: int = DEFAULT_PASSPHRASE_WORDS) -> str:
return generate_phrase(words) return generate_phrase(words)
def generate_rsa_key( def generate_rsa_key(bits: int = DEFAULT_RSA_BITS, password: str | None = None) -> str:
bits: int = DEFAULT_RSA_BITS,
password: str | None = None
) -> str:
""" """
Generate an RSA private key in PEM format. Generate an RSA private key in PEM format.
@@ -99,7 +95,7 @@ def generate_rsa_key(
""" """
key_obj = _generate_rsa_key(bits) key_obj = _generate_rsa_key(bits)
pem_bytes = export_rsa_key_pem(key_obj, password) pem_bytes = export_rsa_key_pem(key_obj, password)
return pem_bytes.decode('utf-8') return pem_bytes.decode("utf-8")
def generate_credentials( def generate_credentials(
@@ -140,8 +136,10 @@ def generate_credentials(
if not use_pin and not use_rsa: if not use_pin and not use_rsa:
raise ValueError("Must select at least one security factor (PIN or RSA key)") raise ValueError("Must select at least one security factor (PIN or RSA key)")
debug.print(f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, " debug.print(
f"passphrase_words={passphrase_words}") f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, "
f"passphrase_words={passphrase_words}"
)
# Generate passphrase (single, not daily) # Generate passphrase (single, not daily)
passphrase = generate_phrase(passphrase_words) passphrase = generate_phrase(passphrase_words)
@@ -154,7 +152,7 @@ def generate_credentials(
if use_rsa: if use_rsa:
rsa_key_obj = _generate_rsa_key(rsa_bits) rsa_key_obj = _generate_rsa_key(rsa_bits)
rsa_key_bytes = export_rsa_key_pem(rsa_key_obj, rsa_password) rsa_key_bytes = export_rsa_key_pem(rsa_key_obj, rsa_password)
rsa_key_pem = rsa_key_bytes.decode('utf-8') rsa_key_pem = rsa_key_bytes.decode("utf-8")
# Create Credentials object (v3.2.0 format) # Create Credentials object (v3.2.0 format)
creds = Credentials( creds = Credentials(

View File

@@ -43,6 +43,7 @@ def get_image_info(image_data: bytes) -> ImageInfo:
if has_dct_support(): if has_dct_support():
try: try:
from .dct_steganography import calculate_dct_capacity from .dct_steganography import calculate_dct_capacity
dct_info = calculate_dct_capacity(image_data) dct_info = calculate_dct_capacity(image_data)
dct_capacity = dct_info.usable_capacity_bytes dct_capacity = dct_info.usable_capacity_bytes
except Exception as e: except Exception as e:
@@ -61,8 +62,10 @@ def get_image_info(image_data: bytes) -> ImageInfo:
dct_capacity_kb=dct_capacity / 1024 if dct_capacity else None, dct_capacity_kb=dct_capacity / 1024 if dct_capacity else None,
) )
debug.print(f"Image info: {width}x{height}, LSB={lsb_capacity} bytes, " debug.print(
f"DCT={dct_capacity or 'N/A'} bytes") f"Image info: {width}x{height}, LSB={lsb_capacity} bytes, "
f"DCT={dct_capacity or 'N/A'} bytes"
)
return info return info
@@ -101,6 +104,7 @@ def compare_capacity(
if dct_available: if dct_available:
try: try:
from .dct_steganography import calculate_dct_capacity from .dct_steganography import calculate_dct_capacity
dct_info = calculate_dct_capacity(carrier_image) dct_info = calculate_dct_capacity(carrier_image)
dct_bytes = dct_info.usable_capacity_bytes dct_bytes = dct_info.usable_capacity_bytes
dct_kb = dct_bytes / 1024 dct_kb = dct_bytes / 1024
@@ -146,7 +150,7 @@ def validate_carrier_capacity(
from .steganography import calculate_capacity_by_mode from .steganography import calculate_capacity_by_mode
capacity_info = calculate_capacity_by_mode(carrier_image, embed_mode) capacity_info = calculate_capacity_by_mode(carrier_image, embed_mode)
capacity = capacity_info['capacity_bytes'] capacity = capacity_info["capacity_bytes"]
# Add encryption overhead estimate # Add encryption overhead estimate
estimated_size = payload_size + 200 # Approximate overhead estimated_size = payload_size + 200 # Approximate overhead
@@ -156,11 +160,11 @@ def validate_carrier_capacity(
headroom = capacity - estimated_size headroom = capacity - estimated_size
return { return {
'fits': fits, "fits": fits,
'capacity': capacity, "capacity": capacity,
'payload_size': payload_size, "payload_size": payload_size,
'estimated_size': estimated_size, "estimated_size": estimated_size,
'usage_percent': min(usage_percent, 100.0), "usage_percent": min(usage_percent, 100.0),
'headroom': headroom, "headroom": headroom,
'mode': embed_mode, "mode": embed_mode,
} }

View File

@@ -50,8 +50,10 @@ def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
>>> generate_pin(6) >>> generate_pin(6)
"812345" "812345"
""" """
debug.validate(MIN_PIN_LENGTH <= length <= MAX_PIN_LENGTH, debug.validate(
f"PIN length must be between {MIN_PIN_LENGTH} and {MAX_PIN_LENGTH}") MIN_PIN_LENGTH <= length <= MAX_PIN_LENGTH,
f"PIN length must be between {MIN_PIN_LENGTH} and {MAX_PIN_LENGTH}",
)
length = max(MIN_PIN_LENGTH, min(MAX_PIN_LENGTH, length)) length = max(MIN_PIN_LENGTH, min(MAX_PIN_LENGTH, length))
@@ -59,7 +61,7 @@ def generate_pin(length: int = DEFAULT_PIN_LENGTH) -> str:
first_digit = str(secrets.randbelow(9) + 1) first_digit = str(secrets.randbelow(9) + 1)
# Remaining digits: 0-9 # Remaining digits: 0-9
rest = ''.join(str(secrets.randbelow(10)) for _ in range(length - 1)) rest = "".join(str(secrets.randbelow(10)) for _ in range(length - 1))
pin = first_digit + rest pin = first_digit + rest
debug.print(f"Generated PIN: {pin}") debug.print(f"Generated PIN: {pin}")
@@ -80,14 +82,16 @@ def generate_phrase(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> str:
>>> generate_phrase(4) >>> generate_phrase(4)
"apple forest thunder mountain" "apple forest thunder mountain"
""" """
debug.validate(MIN_PASSPHRASE_WORDS <= words_per_phrase <= MAX_PASSPHRASE_WORDS, debug.validate(
f"Words per phrase must be between {MIN_PASSPHRASE_WORDS} and {MAX_PASSPHRASE_WORDS}") MIN_PASSPHRASE_WORDS <= words_per_phrase <= MAX_PASSPHRASE_WORDS,
f"Words per phrase must be between {MIN_PASSPHRASE_WORDS} and {MAX_PASSPHRASE_WORDS}",
)
words_per_phrase = max(MIN_PASSPHRASE_WORDS, min(MAX_PASSPHRASE_WORDS, words_per_phrase)) words_per_phrase = max(MIN_PASSPHRASE_WORDS, min(MAX_PASSPHRASE_WORDS, words_per_phrase))
wordlist = get_wordlist() wordlist = get_wordlist()
words = [secrets.choice(wordlist) for _ in range(words_per_phrase)] words = [secrets.choice(wordlist) for _ in range(words_per_phrase)]
phrase = ' '.join(words) phrase = " ".join(words)
debug.print(f"Generated phrase: {phrase}") debug.print(f"Generated phrase: {phrase}")
return phrase return phrase
@@ -114,11 +118,12 @@ def generate_day_phrases(words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS) -> di
{'Monday': 'apple forest thunder', 'Tuesday': 'banana river lightning', ...} {'Monday': 'apple forest thunder', 'Tuesday': 'banana river lightning', ...}
""" """
import warnings import warnings
warnings.warn( warnings.warn(
"generate_day_phrases() is deprecated in v3.2.0. " "generate_day_phrases() is deprecated in v3.2.0. "
"Use generate_phrase() for single passphrase.", "Use generate_phrase() for single passphrase.",
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2,
) )
phrases = {day: generate_phrase(words_per_phrase) for day in DAY_NAMES} phrases = {day: generate_phrase(words_per_phrase) for day in DAY_NAMES}
@@ -144,8 +149,7 @@ def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
>>> key.key_size >>> key.key_size
2048 2048
""" """
debug.validate(bits in VALID_RSA_SIZES, debug.validate(bits in VALID_RSA_SIZES, f"RSA key size must be one of {VALID_RSA_SIZES}")
f"RSA key size must be one of {VALID_RSA_SIZES}")
if bits not in VALID_RSA_SIZES: if bits not in VALID_RSA_SIZES:
bits = DEFAULT_RSA_BITS bits = DEFAULT_RSA_BITS
@@ -153,9 +157,7 @@ def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
debug.print(f"Generating {bits}-bit RSA key...") debug.print(f"Generating {bits}-bit RSA key...")
try: try:
key = rsa.generate_private_key( key = rsa.generate_private_key(
public_exponent=65537, public_exponent=65537, key_size=bits, backend=default_backend()
key_size=bits,
backend=default_backend()
) )
debug.print(f"RSA key generated: {bits} bits") debug.print(f"RSA key generated: {bits} bits")
return key return key
@@ -164,10 +166,7 @@ def generate_rsa_key(bits: int = DEFAULT_RSA_BITS) -> rsa.RSAPrivateKey:
raise KeyGenerationError(f"Failed to generate RSA key: {e}") from e raise KeyGenerationError(f"Failed to generate RSA key: {e}") from e
def export_rsa_key_pem( def export_rsa_key_pem(private_key: rsa.RSAPrivateKey, password: str | None = None) -> bytes:
private_key: rsa.RSAPrivateKey,
password: str | None = None
) -> bytes:
""" """
Export RSA key to PEM format. Export RSA key to PEM format.
@@ -198,14 +197,11 @@ def export_rsa_key_pem(
return private_key.private_bytes( return private_key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8, format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=encryption_algorithm encryption_algorithm=encryption_algorithm,
) )
def load_rsa_key( def load_rsa_key(key_data: bytes, password: str | None = None) -> rsa.RSAPrivateKey:
key_data: bytes,
password: str | None = None
) -> rsa.RSAPrivateKey:
""" """
Load RSA private key from PEM data. Load RSA private key from PEM data.
@@ -223,8 +219,7 @@ def load_rsa_key(
Example: Example:
>>> key = load_rsa_key(pem_data, "my_password") >>> key = load_rsa_key(pem_data, "my_password")
""" """
debug.validate(key_data is not None and len(key_data) > 0, debug.validate(key_data is not None and len(key_data) > 0, "Key data cannot be empty")
"Key data cannot be empty")
try: try:
pwd_bytes = password.encode() if password else None pwd_bytes = password.encode() if password else None
@@ -274,15 +269,11 @@ def get_key_info(key_data: bytes, password: str | None = None) -> KeyInfo:
""" """
debug.print("Getting RSA key info") debug.print("Getting RSA key info")
# Check if encrypted # Check if encrypted
is_encrypted = b'ENCRYPTED' in key_data is_encrypted = b"ENCRYPTED" in key_data
private_key = load_rsa_key(key_data, password) private_key = load_rsa_key(key_data, password)
info = KeyInfo( info = KeyInfo(key_size=private_key.key_size, is_encrypted=is_encrypted, pem_data=key_data)
key_size=private_key.key_size,
is_encrypted=is_encrypted,
pem_data=key_data
)
debug.print(f"Key info: {info.key_size} bits, encrypted: {info.is_encrypted}") debug.print(f"Key info: {info.key_size} bits, encrypted: {info.is_encrypted}")
return info return info
@@ -323,14 +314,15 @@ def generate_credentials(
>>> creds.pin >>> creds.pin
"812345" "812345"
""" """
debug.validate(use_pin or use_rsa, debug.validate(use_pin or use_rsa, "Must select at least one security factor (PIN or RSA key)")
"Must select at least one security factor (PIN or RSA key)")
if not use_pin and not use_rsa: if not use_pin and not use_rsa:
raise ValueError("Must select at least one security factor (PIN or RSA key)") raise ValueError("Must select at least one security factor (PIN or RSA key)")
debug.print(f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, " debug.print(
f"passphrase_words={passphrase_words}") f"Generating credentials: PIN={use_pin}, RSA={use_rsa}, "
f"passphrase_words={passphrase_words}"
)
# Generate single passphrase (v3.2.0 - no daily rotation) # Generate single passphrase (v3.2.0 - no daily rotation)
passphrase = generate_phrase(passphrase_words) passphrase = generate_phrase(passphrase_words)
@@ -342,7 +334,7 @@ def generate_credentials(
rsa_key_pem = None rsa_key_pem = None
if use_rsa: if use_rsa:
rsa_key_obj = generate_rsa_key(rsa_bits) rsa_key_obj = generate_rsa_key(rsa_bits)
rsa_key_pem = export_rsa_key_pem(rsa_key_obj, rsa_password).decode('utf-8') rsa_key_pem = export_rsa_key_pem(rsa_key_obj, rsa_password).decode("utf-8")
# Create Credentials object (v3.2.0 format with single passphrase) # Create Credentials object (v3.2.0 format with single passphrase)
creds = Credentials( creds = Credentials(
@@ -361,12 +353,13 @@ def generate_credentials(
# LEGACY COMPATIBILITY # LEGACY COMPATIBILITY
# ============================================================================= # =============================================================================
def generate_credentials_legacy( def generate_credentials_legacy(
use_pin: bool = True, use_pin: bool = True,
use_rsa: bool = False, use_rsa: bool = False,
pin_length: int = DEFAULT_PIN_LENGTH, pin_length: int = DEFAULT_PIN_LENGTH,
rsa_bits: int = DEFAULT_RSA_BITS, rsa_bits: int = DEFAULT_RSA_BITS,
words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS words_per_phrase: int = DEFAULT_PASSPHRASE_WORDS,
) -> dict: ) -> dict:
""" """
Generate credentials in legacy format (v3.1.0 style with daily phrases). Generate credentials in legacy format (v3.1.0 style with daily phrases).
@@ -387,11 +380,12 @@ def generate_credentials_legacy(
Dict with 'phrases' (dict), 'pin', 'rsa_key_pem', etc. Dict with 'phrases' (dict), 'pin', 'rsa_key_pem', etc.
""" """
import warnings import warnings
warnings.warn( warnings.warn(
"generate_credentials_legacy() returns v3.1.0 format. " "generate_credentials_legacy() returns v3.1.0 format. "
"Use generate_credentials() for v3.2.0 format.", "Use generate_credentials() for v3.2.0 format.",
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2,
) )
if not use_pin and not use_rsa: if not use_pin and not use_rsa:
@@ -405,12 +399,12 @@ def generate_credentials_legacy(
rsa_key_pem = None rsa_key_pem = None
if use_rsa: if use_rsa:
rsa_key_obj = generate_rsa_key(rsa_bits) rsa_key_obj = generate_rsa_key(rsa_bits)
rsa_key_pem = export_rsa_key_pem(rsa_key_obj).decode('utf-8') rsa_key_pem = export_rsa_key_pem(rsa_key_obj).decode("utf-8")
return { return {
'phrases': phrases, "phrases": phrases,
'pin': pin, "pin": pin,
'rsa_key_pem': rsa_key_pem, "rsa_key_pem": rsa_key_pem,
'rsa_bits': rsa_bits if use_rsa else None, "rsa_bits": rsa_bits if use_rsa else None,
'words_per_phrase': words_per_phrase, "words_per_phrase": words_per_phrase,
} }

View File

@@ -21,6 +21,7 @@ class Credentials:
v3.2.0: Simplified to use single passphrase instead of daily rotation. v3.2.0: Simplified to use single passphrase instead of daily rotation.
""" """
passphrase: str # Single passphrase (no daily rotation) passphrase: str # Single passphrase (no daily rotation)
pin: str | None = None pin: str | None = None
rsa_key_pem: str | None = None rsa_key_pem: str | None = None
@@ -64,6 +65,7 @@ class Credentials:
@dataclass @dataclass
class FilePayload: class FilePayload:
"""Represents a file to be embedded.""" """Represents a file to be embedded."""
data: bytes data: bytes
filename: str filename: str
mime_type: str | None = None mime_type: str | None = None
@@ -73,7 +75,7 @@ class FilePayload:
return len(self.data) return len(self.data)
@classmethod @classmethod
def from_file(cls, filepath: str, filename: str | None = None) -> 'FilePayload': def from_file(cls, filepath: str, filename: str | None = None) -> "FilePayload":
"""Create FilePayload from a file path.""" """Create FilePayload from a file path."""
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
@@ -93,6 +95,7 @@ class EncodeInput:
v3.2.0: Removed date_str (date no longer used in crypto). v3.2.0: Removed date_str (date no longer used in crypto).
""" """
message: str | bytes | FilePayload # Text, raw bytes, or file message: str | bytes | FilePayload # Text, raw bytes, or file
reference_photo: bytes reference_photo: bytes
carrier_image: bytes carrier_image: bytes
@@ -109,6 +112,7 @@ class EncodeResult:
v3.2.0: date_used is now optional/cosmetic (not used in crypto). v3.2.0: date_used is now optional/cosmetic (not used in crypto).
""" """
stego_image: bytes stego_image: bytes
filename: str filename: str
pixels_modified: int pixels_modified: int
@@ -129,6 +133,7 @@ class DecodeInput:
v3.2.0: Renamed day_phrase → passphrase, no date needed. v3.2.0: Renamed day_phrase → passphrase, no date needed.
""" """
stego_image: bytes stego_image: bytes
reference_photo: bytes reference_photo: bytes
passphrase: str # Renamed from day_phrase passphrase: str # Renamed from day_phrase
@@ -144,6 +149,7 @@ class DecodeResult:
v3.2.0: date_encoded is always None (date removed from crypto). v3.2.0: date_encoded is always None (date removed from crypto).
""" """
payload_type: str # 'text' or 'file' payload_type: str # 'text' or 'file'
message: str | None = None # For text payloads message: str | None = None # For text payloads
file_data: bytes | None = None # For file payloads file_data: bytes | None = None # For file payloads
@@ -153,11 +159,11 @@ class DecodeResult:
@property @property
def is_file(self) -> bool: def is_file(self) -> bool:
return self.payload_type == 'file' return self.payload_type == "file"
@property @property
def is_text(self) -> bool: def is_text(self) -> bool:
return self.payload_type == 'text' return self.payload_type == "text"
def get_content(self) -> str | bytes: def get_content(self) -> str | bytes:
"""Get the decoded content (text or bytes).""" """Get the decoded content (text or bytes)."""
@@ -169,6 +175,7 @@ class DecodeResult:
@dataclass @dataclass
class EmbedStats: class EmbedStats:
"""Statistics from image embedding.""" """Statistics from image embedding."""
pixels_modified: int pixels_modified: int
total_pixels: int total_pixels: int
capacity_used: float capacity_used: float
@@ -183,6 +190,7 @@ class EmbedStats:
@dataclass @dataclass
class KeyInfo: class KeyInfo:
"""Information about an RSA key.""" """Information about an RSA key."""
key_size: int key_size: int
is_encrypted: bool is_encrypted: bool
pem_data: bytes pem_data: bytes
@@ -191,13 +199,14 @@ class KeyInfo:
@dataclass @dataclass
class ValidationResult: class ValidationResult:
"""Result of input validation.""" """Result of input validation."""
is_valid: bool is_valid: bool
error_message: str = "" error_message: str = ""
details: dict = field(default_factory=dict) details: dict = field(default_factory=dict)
warning: str | None = None # v3.2.0: Added for passphrase length warnings warning: str | None = None # v3.2.0: Added for passphrase length warnings
@classmethod @classmethod
def ok(cls, warning: str | None = None, **details) -> 'ValidationResult': def ok(cls, warning: str | None = None, **details) -> "ValidationResult":
"""Create a successful validation result.""" """Create a successful validation result."""
result = cls(is_valid=True, details=details) result = cls(is_valid=True, details=details)
if warning: if warning:
@@ -205,7 +214,7 @@ class ValidationResult:
return result return result
@classmethod @classmethod
def error(cls, message: str, **details) -> 'ValidationResult': def error(cls, message: str, **details) -> "ValidationResult":
"""Create a failed validation result.""" """Create a failed validation result."""
return cls(is_valid=False, error_message=message, details=details) return cls(is_valid=False, error_message=message, details=details)
@@ -214,9 +223,11 @@ class ValidationResult:
# NEW MODELS FOR V3.2.0 PUBLIC API # NEW MODELS FOR V3.2.0 PUBLIC API
# ============================================================================= # =============================================================================
@dataclass @dataclass
class ImageInfo: class ImageInfo:
"""Information about an image for steganography.""" """Information about an image for steganography."""
width: int width: int
height: int height: int
pixels: int pixels: int
@@ -232,6 +243,7 @@ class ImageInfo:
@dataclass @dataclass
class CapacityComparison: class CapacityComparison:
"""Comparison of embedding capacity between modes.""" """Comparison of embedding capacity between modes."""
image_width: int image_width: int
image_height: int image_height: int
lsb_available: bool lsb_available: bool
@@ -248,6 +260,7 @@ class CapacityComparison:
@dataclass @dataclass
class GenerateResult: class GenerateResult:
"""Result of credential generation.""" """Result of credential generation."""
passphrase: str passphrase: str
pin: str | None = None pin: str | None = None
rsa_key_pem: str | None = None rsa_key_pem: str | None = None

View File

@@ -20,6 +20,7 @@ from PIL import Image
try: try:
import qrcode import qrcode
from qrcode.constants import ERROR_CORRECT_L, ERROR_CORRECT_M from qrcode.constants import ERROR_CORRECT_L, ERROR_CORRECT_M
HAS_QRCODE_WRITE = True HAS_QRCODE_WRITE = True
except ImportError: except ImportError:
HAS_QRCODE_WRITE = False HAS_QRCODE_WRITE = False
@@ -28,6 +29,7 @@ except ImportError:
try: try:
from pyzbar.pyzbar import ZBarSymbol from pyzbar.pyzbar import ZBarSymbol
from pyzbar.pyzbar import decode as pyzbar_decode from pyzbar.pyzbar import decode as pyzbar_decode
HAS_QRCODE_READ = True HAS_QRCODE_READ = True
except ImportError: except ImportError:
HAS_QRCODE_READ = False HAS_QRCODE_READ = False
@@ -53,8 +55,8 @@ def compress_data(data: str) -> str:
Returns: Returns:
Compressed string with STEGASOO-Z: prefix Compressed string with STEGASOO-Z: prefix
""" """
compressed = zlib.compress(data.encode('utf-8'), level=9) compressed = zlib.compress(data.encode("utf-8"), level=9)
encoded = base64.b64encode(compressed).decode('ascii') encoded = base64.b64encode(compressed).decode("ascii")
return COMPRESSION_PREFIX + encoded return COMPRESSION_PREFIX + encoded
@@ -76,7 +78,7 @@ def decompress_data(data: str) -> str:
encoded = data[len(COMPRESSION_PREFIX) :] encoded = data[len(COMPRESSION_PREFIX) :]
compressed = base64.b64decode(encoded) compressed = base64.b64decode(encoded)
return zlib.decompress(compressed).decode('utf-8') return zlib.decompress(compressed).decode("utf-8")
def normalize_pem(pem_data: str) -> str: def normalize_pem(pem_data: str) -> str:
@@ -101,25 +103,25 @@ def normalize_pem(pem_data: str) -> str:
import re import re
# Step 1: Normalize ALL line endings to \n # Step 1: Normalize ALL line endings to \n
pem_data = pem_data.replace('\r\n', '\n').replace('\r', '\n') pem_data = pem_data.replace("\r\n", "\n").replace("\r", "\n")
# Step 2: Remove leading/trailing whitespace # Step 2: Remove leading/trailing whitespace
pem_data = pem_data.strip() pem_data = pem_data.strip()
# Step 3: Remove any non-ASCII characters (QR artifacts) # Step 3: Remove any non-ASCII characters (QR artifacts)
pem_data = ''.join(char for char in pem_data if ord(char) < 128) pem_data = "".join(char for char in pem_data if ord(char) < 128)
# Step 4: Extract header, content, and footer with flexible regex # Step 4: Extract header, content, and footer with flexible regex
# This handles variations like: # This handles variations like:
# - "PRIVATE KEY" vs "RSA PRIVATE KEY" # - "PRIVATE KEY" vs "RSA PRIVATE KEY"
# - Extra spaces in headers # - Extra spaces in headers
# - Missing spaces # - Missing spaces
pattern = r'(-----BEGIN[^-]*-----)(.*?)(-----END[^-]*-----)' pattern = r"(-----BEGIN[^-]*-----)(.*?)(-----END[^-]*-----)"
match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE) match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE)
if not match: if not match:
# Fallback: try even more permissive pattern # Fallback: try even more permissive pattern
pattern = r'(-+BEGIN[^-]+-+)(.*?)(-+END[^-]+-+)' pattern = r"(-+BEGIN[^-]+-+)(.*?)(-+END[^-]+-+)"
match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE) match = re.search(pattern, pem_data, re.DOTALL | re.IGNORECASE)
if not match: if not match:
@@ -132,38 +134,35 @@ def normalize_pem(pem_data: str) -> str:
# Step 5: Normalize header and footer # Step 5: Normalize header and footer
# Standardize spacing and ensure proper format # Standardize spacing and ensure proper format
header = re.sub(r'\s+', ' ', header_raw) header = re.sub(r"\s+", " ", header_raw)
footer = re.sub(r'\s+', ' ', footer_raw) footer = re.sub(r"\s+", " ", footer_raw)
# Ensure exactly 5 dashes on each side # Ensure exactly 5 dashes on each side
header = re.sub(r'^-+', '-----', header) header = re.sub(r"^-+", "-----", header)
header = re.sub(r'-+$', '-----', header) header = re.sub(r"-+$", "-----", header)
footer = re.sub(r'^-+', '-----', footer) footer = re.sub(r"^-+", "-----", footer)
footer = re.sub(r'-+$', '-----', footer) footer = re.sub(r"-+$", "-----", footer)
# Step 6: Clean the base64 content THOROUGHLY # Step 6: Clean the base64 content THOROUGHLY
# Remove ALL whitespace: spaces, tabs, newlines # Remove ALL whitespace: spaces, tabs, newlines
# Keep only valid base64 characters: A-Z, a-z, 0-9, +, /, = # Keep only valid base64 characters: A-Z, a-z, 0-9, +, /, =
content_clean = ''.join( content_clean = "".join(char for char in content_raw if char.isalnum() or char in "+/=")
char for char in content_raw
if char.isalnum() or char in '+/='
)
# Double-check: remove any remaining invalid characters # Double-check: remove any remaining invalid characters
content_clean = re.sub(r'[^A-Za-z0-9+/=]', '', content_clean) content_clean = re.sub(r"[^A-Za-z0-9+/=]", "", content_clean)
# Step 7: Fix base64 padding # Step 7: Fix base64 padding
# Base64 strings must be divisible by 4 # Base64 strings must be divisible by 4
remainder = len(content_clean) % 4 remainder = len(content_clean) % 4
if remainder: if remainder:
content_clean += '=' * (4 - remainder) content_clean += "=" * (4 - remainder)
# Step 8: Split into 64-character lines (PEM standard) # Step 8: Split into 64-character lines (PEM standard)
lines = [content_clean[i : i + 64] for i in range(0, len(content_clean), 64)] lines = [content_clean[i : i + 64] for i in range(0, len(content_clean), 64)]
# Step 9: Reconstruct with EXACT PEM formatting # Step 9: Reconstruct with EXACT PEM formatting
# Format: header\ncontent_line1\ncontent_line2\n...\nfooter\n # Format: header\ncontent_line1\ncontent_line2\n...\nfooter\n
return header + '\n' + '\n'.join(lines) + '\n' + footer + '\n' return header + "\n" + "\n".join(lines) + "\n" + footer + "\n"
def is_compressed(data: str) -> bool: def is_compressed(data: str) -> bool:
@@ -205,7 +204,7 @@ def can_fit_in_qr(data: str, compress: bool = False) -> bool:
if compress: if compress:
size = get_compressed_size(data) size = get_compressed_size(data)
else: else:
size = len(data.encode('utf-8')) size = len(data.encode("utf-8"))
return size <= QR_MAX_BINARY return size <= QR_MAX_BINARY
@@ -214,11 +213,7 @@ def needs_compression(data: str) -> bool:
return not can_fit_in_qr(data, compress=False) and can_fit_in_qr(data, compress=True) return not can_fit_in_qr(data, compress=False) and can_fit_in_qr(data, compress=True)
def generate_qr_code( def generate_qr_code(data: str, compress: bool = False, error_correction=None) -> bytes:
data: str,
compress: bool = False,
error_correction=None
) -> bytes:
""" """
Generate a QR code PNG from string data. Generate a QR code PNG from string data.
@@ -244,10 +239,9 @@ def generate_qr_code(
qr_data = compress_data(data) qr_data = compress_data(data)
# Check size # Check size
if len(qr_data.encode('utf-8')) > QR_MAX_BINARY: if len(qr_data.encode("utf-8")) > QR_MAX_BINARY:
raise ValueError( raise ValueError(
f"Data too large for QR code ({len(qr_data)} bytes). " f"Data too large for QR code ({len(qr_data)} bytes). " f"Maximum: {QR_MAX_BINARY} bytes"
f"Maximum: {QR_MAX_BINARY} bytes"
) )
# Use lower error correction for larger data # Use lower error correction for larger data
@@ -266,7 +260,7 @@ def generate_qr_code(
img = qr.make_image(fill_color="black", back_color="white") img = qr.make_image(fill_color="black", back_color="white")
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format='PNG') img.save(buf, format="PNG")
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
@@ -294,8 +288,8 @@ def read_qr_code(image_data: bytes) -> str | None:
img: Image.Image = Image.open(io.BytesIO(image_data)) img: Image.Image = Image.open(io.BytesIO(image_data))
# Convert to RGB if necessary (pyzbar works best with RGB/grayscale) # Convert to RGB if necessary (pyzbar works best with RGB/grayscale)
if img.mode not in ('RGB', 'L'): if img.mode not in ("RGB", "L"):
img = img.convert('RGB') img = img.convert("RGB")
# Decode QR codes # Decode QR codes
decoded = pyzbar_decode(img, symbols=[ZBarSymbol.QRCODE]) decoded = pyzbar_decode(img, symbols=[ZBarSymbol.QRCODE])
@@ -304,7 +298,7 @@ def read_qr_code(image_data: bytes) -> str | None:
return None return None
# Return first QR code found # Return first QR code found
result: str = decoded[0].data.decode('utf-8') result: str = decoded[0].data.decode("utf-8")
return result return result
except Exception: except Exception:
@@ -321,7 +315,7 @@ def read_qr_code_from_file(filepath: str) -> str | None:
Returns: Returns:
Decoded string, or None if no QR code found Decoded string, or None if no QR code found
""" """
with open(filepath, 'rb') as f: with open(filepath, "rb") as f:
return read_qr_code(f.read()) return read_qr_code(f.read())
@@ -355,7 +349,7 @@ def extract_key_from_qr(image_data: bytes) -> str | None:
key_pem = qr_data key_pem = qr_data
# Step 3: Validate it looks like a PEM key # Step 3: Validate it looks like a PEM key
if '-----BEGIN' not in key_pem or '-----END' not in key_pem: if "-----BEGIN" not in key_pem or "-----END" not in key_pem:
return None return None
# Step 4: Aggressively normalize PEM format # Step 4: Aggressively normalize PEM format
@@ -367,7 +361,7 @@ def extract_key_from_qr(image_data: bytes) -> str | None:
return None return None
# Step 5: Final validation - ensure it still looks like PEM # Step 5: Final validation - ensure it still looks like PEM
if '-----BEGIN' in key_pem and '-----END' in key_pem: if "-----BEGIN" in key_pem and "-----END" in key_pem:
return key_pem return key_pem
return None return None
@@ -383,14 +377,14 @@ def extract_key_from_qr_file(filepath: str) -> str | None:
Returns: Returns:
PEM-encoded RSA key string, or None if not found/invalid PEM-encoded RSA key string, or None if not found/invalid
""" """
with open(filepath, 'rb') as f: with open(filepath, "rb") as f:
return extract_key_from_qr(f.read()) return extract_key_from_qr(f.read())
def detect_and_crop_qr( def detect_and_crop_qr(
image_data: bytes, image_data: bytes,
padding_percent: float = QR_CROP_PADDING_PERCENT, padding_percent: float = QR_CROP_PADDING_PERCENT,
min_padding_px: int = QR_CROP_MIN_PADDING_PX min_padding_px: int = QR_CROP_MIN_PADDING_PX,
) -> bytes | None: ) -> bytes | None:
""" """
Detect QR code in image and crop to it, handling rotation. Detect QR code in image and crop to it, handling rotation.
@@ -420,8 +414,8 @@ def detect_and_crop_qr(
original_mode = img.mode original_mode = img.mode
# Convert for pyzbar detection # Convert for pyzbar detection
if img.mode not in ('RGB', 'L'): if img.mode not in ("RGB", "L"):
detect_img = img.convert('RGB') detect_img = img.convert("RGB")
else: else:
detect_img = img detect_img = img
@@ -468,16 +462,17 @@ def detect_and_crop_qr(
# Convert to PNG bytes # Convert to PNG bytes
buf = io.BytesIO() buf = io.BytesIO()
# Preserve transparency if present # Preserve transparency if present
if original_mode in ('RGBA', 'LA', 'P'): if original_mode in ("RGBA", "LA", "P"):
cropped.save(buf, format='PNG') cropped.save(buf, format="PNG")
else: else:
cropped.save(buf, format='PNG') cropped.save(buf, format="PNG")
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
except Exception as e: except Exception as e:
# Log for debugging but return None for clean API # Log for debugging but return None for clean API
import sys import sys
print(f"QR crop error: {e}", file=sys.stderr) print(f"QR crop error: {e}", file=sys.stderr)
return None return None
@@ -485,7 +480,7 @@ def detect_and_crop_qr(
def detect_and_crop_qr_file( def detect_and_crop_qr_file(
filepath: str, filepath: str,
padding_percent: float = QR_CROP_PADDING_PERCENT, padding_percent: float = QR_CROP_PADDING_PERCENT,
min_padding_px: int = QR_CROP_MIN_PADDING_PX min_padding_px: int = QR_CROP_MIN_PADDING_PX,
) -> bytes | None: ) -> bytes | None:
""" """
Detect QR code in image file and crop to it. Detect QR code in image file and crop to it.
@@ -498,7 +493,7 @@ def detect_and_crop_qr_file(
Returns: Returns:
Cropped PNG image bytes, or None if no QR code found Cropped PNG image bytes, or None if no QR code found
""" """
with open(filepath, 'rb') as f: with open(filepath, "rb") as f:
return detect_and_crop_qr(f.read(), padding_percent, min_padding_px) return detect_and_crop_qr(f.read(), padding_percent, min_padding_px)

View File

@@ -40,21 +40,21 @@ from .exceptions import CapacityError, EmbeddingError
from .models import EmbedStats, FilePayload from .models import EmbedStats, FilePayload
# Lossless formats that preserve LSB data # Lossless formats that preserve LSB data
LOSSLESS_FORMATS = {'PNG', 'BMP', 'TIFF'} LOSSLESS_FORMATS = {"PNG", "BMP", "TIFF"}
# Format to extension mapping # Format to extension mapping
FORMAT_TO_EXT = { FORMAT_TO_EXT = {
'PNG': 'png', "PNG": "png",
'BMP': 'bmp', "BMP": "bmp",
'TIFF': 'tiff', "TIFF": "tiff",
} }
# Extension to PIL format mapping # Extension to PIL format mapping
EXT_TO_FORMAT = { EXT_TO_FORMAT = {
'png': 'PNG', "png": "PNG",
'bmp': 'BMP', "bmp": "BMP",
'tiff': 'TIFF', "tiff": "TIFF",
'tif': 'TIFF', "tif": "TIFF",
} }
# ============================================================================= # =============================================================================
@@ -78,12 +78,12 @@ LENGTH_PREFIX = 4 # 4 bytes for payload length in LSB embedding
ENCRYPTION_OVERHEAD = HEADER_OVERHEAD + LENGTH_PREFIX # 70 bytes total ENCRYPTION_OVERHEAD = HEADER_OVERHEAD + LENGTH_PREFIX # 70 bytes total
# DCT output format options (v3.0.1) # DCT output format options (v3.0.1)
DCT_OUTPUT_PNG = 'png' DCT_OUTPUT_PNG = "png"
DCT_OUTPUT_JPEG = 'jpeg' DCT_OUTPUT_JPEG = "jpeg"
# DCT color mode options (v3.0.1) # DCT color mode options (v3.0.1)
DCT_COLOR_GRAYSCALE = 'grayscale' DCT_COLOR_GRAYSCALE = "grayscale"
DCT_COLOR_COLOR = 'color' DCT_COLOR_COLOR = "color"
# ============================================================================= # =============================================================================
@@ -98,6 +98,7 @@ def _get_dct_module():
global _dct_module global _dct_module
if _dct_module is None: if _dct_module is None:
from . import dct_steganography from . import dct_steganography
_dct_module = dct_steganography _dct_module = dct_steganography
return _dct_module return _dct_module
@@ -124,6 +125,7 @@ def has_dct_support() -> bool:
# FORMAT UTILITIES # FORMAT UTILITIES
# ============================================================================= # =============================================================================
def get_output_format(input_format: str | None) -> tuple[str, str]: def get_output_format(input_format: str | None) -> tuple[str, str]:
""" """
Determine the output format based on input format. Determine the output format based on input format.
@@ -135,23 +137,25 @@ def get_output_format(input_format: str | None) -> tuple[str, str]:
Tuple of (PIL format string, file extension) for output Tuple of (PIL format string, file extension) for output
Falls back to PNG for lossy or unknown formats. Falls back to PNG for lossy or unknown formats.
""" """
debug.validate(input_format is None or isinstance(input_format, str), debug.validate(
"Input format must be string or None") input_format is None or isinstance(input_format, str), "Input format must be string or None"
)
if input_format and input_format.upper() in LOSSLESS_FORMATS: if input_format and input_format.upper() in LOSSLESS_FORMATS:
fmt = input_format.upper() fmt = input_format.upper()
ext = FORMAT_TO_EXT.get(fmt, 'png') ext = FORMAT_TO_EXT.get(fmt, "png")
debug.print(f"Using lossless format: {fmt} -> .{ext}") debug.print(f"Using lossless format: {fmt} -> .{ext}")
return fmt, ext return fmt, ext
debug.print(f"Input format {input_format} is lossy or unknown, defaulting to PNG") debug.print(f"Input format {input_format} is lossy or unknown, defaulting to PNG")
return 'PNG', 'png' return "PNG", "png"
# ============================================================================= # =============================================================================
# CAPACITY FUNCTIONS # CAPACITY FUNCTIONS
# ============================================================================= # =============================================================================
def will_fit( def will_fit(
payload: str | bytes | FilePayload | int, payload: str | bytes | FilePayload | int,
carrier_image: bytes, carrier_image: bytes,
@@ -175,12 +179,12 @@ def will_fit(
payload_size = payload payload_size = payload
payload_data = None payload_data = None
elif isinstance(payload, str): elif isinstance(payload, str):
payload_data = payload.encode('utf-8') payload_data = payload.encode("utf-8")
payload_size = len(payload_data) payload_size = len(payload_data)
elif isinstance(payload, FilePayload): elif isinstance(payload, FilePayload):
payload_data = payload.data payload_data = payload.data
filename_overhead = len(payload.filename.encode('utf-8')) if payload.filename else 0 filename_overhead = len(payload.filename.encode("utf-8")) if payload.filename else 0
mime_overhead = len(payload.mime_type.encode('utf-8')) if payload.mime_type else 0 mime_overhead = len(payload.mime_type.encode("utf-8")) if payload.mime_type else 0
payload_size = len(payload.data) + filename_overhead + mime_overhead + 5 payload_size = len(payload.data) + filename_overhead + mime_overhead + 5
else: else:
payload_data = payload payload_data = payload
@@ -198,6 +202,7 @@ def will_fit(
if include_compression_estimate and payload_data is not None and len(payload_data) >= 64: if include_compression_estimate and payload_data is not None and len(payload_data) >= 64:
try: try:
import zlib import zlib
compressed = zlib.compress(payload_data, level=6) compressed = zlib.compress(payload_data, level=6)
compressed_size = len(compressed) + 9 # Compression header compressed_size = len(compressed) + 9 # Compression header
if compressed_size < payload_size: if compressed_size < payload_size:
@@ -211,14 +216,14 @@ def will_fit(
usage_percent = (estimated_encrypted_size / capacity * 100) if capacity > 0 else 100.0 usage_percent = (estimated_encrypted_size / capacity * 100) if capacity > 0 else 100.0
return { return {
'fits': fits, "fits": fits,
'payload_size': payload_size, "payload_size": payload_size,
'estimated_encrypted_size': estimated_encrypted_size, "estimated_encrypted_size": estimated_encrypted_size,
'capacity': capacity, "capacity": capacity,
'usage_percent': min(usage_percent, 100.0), "usage_percent": min(usage_percent, 100.0),
'headroom': headroom, "headroom": headroom,
'compressed_estimate': compressed_estimate, "compressed_estimate": compressed_estimate,
'mode': EMBED_MODE_LSB, "mode": EMBED_MODE_LSB,
} }
@@ -233,8 +238,9 @@ def calculate_capacity(image_data: bytes, bits_per_channel: int = 1) -> int:
Returns: Returns:
Maximum bytes that can be embedded (minus overhead) Maximum bytes that can be embedded (minus overhead)
""" """
debug.validate(bits_per_channel in (1, 2), debug.validate(
f"bits_per_channel must be 1 or 2, got {bits_per_channel}") bits_per_channel in (1, 2), f"bits_per_channel must be 1 or 2, got {bits_per_channel}"
)
img_file = Image.open(io.BytesIO(image_data)) img_file = Image.open(io.BytesIO(image_data))
try: try:
@@ -273,12 +279,12 @@ def calculate_capacity_by_mode(
dct_info = dct_mod.calculate_dct_capacity(image_data) dct_info = dct_mod.calculate_dct_capacity(image_data)
return { return {
'mode': EMBED_MODE_DCT, "mode": EMBED_MODE_DCT,
'capacity_bytes': dct_info.usable_capacity_bytes, "capacity_bytes": dct_info.usable_capacity_bytes,
'capacity_bits': dct_info.total_capacity_bits, "capacity_bits": dct_info.total_capacity_bits,
'width': dct_info.width, "width": dct_info.width,
'height': dct_info.height, "height": dct_info.height,
'total_blocks': dct_info.total_blocks, "total_blocks": dct_info.total_blocks,
} }
else: else:
capacity = calculate_capacity(image_data, bits_per_channel) capacity = calculate_capacity(image_data, bits_per_channel)
@@ -289,12 +295,12 @@ def calculate_capacity_by_mode(
img.close() img.close()
return { return {
'mode': EMBED_MODE_LSB, "mode": EMBED_MODE_LSB,
'capacity_bytes': capacity, "capacity_bytes": capacity,
'capacity_bits': capacity * 8, "capacity_bits": capacity * 8,
'width': width, "width": width,
'height': height, "height": height,
'bits_per_channel': bits_per_channel, "bits_per_channel": bits_per_channel,
} }
@@ -318,13 +324,13 @@ def will_fit_by_mode(
""" """
if embed_mode == EMBED_MODE_DCT: if embed_mode == EMBED_MODE_DCT:
if not has_dct_support(): if not has_dct_support():
return {'fits': False, 'error': 'scipy not available', 'mode': EMBED_MODE_DCT} return {"fits": False, "error": "scipy not available", "mode": EMBED_MODE_DCT}
if isinstance(payload, int): if isinstance(payload, int):
payload_size = payload payload_size = payload
elif isinstance(payload, str): elif isinstance(payload, str):
payload_size = len(payload.encode('utf-8')) payload_size = len(payload.encode("utf-8"))
elif hasattr(payload, 'data'): elif hasattr(payload, "data"):
payload_size = len(payload.data) payload_size = len(payload.data)
else: else:
payload_size = len(payload) payload_size = len(payload)
@@ -339,12 +345,12 @@ def will_fit_by_mode(
usage_percent = (estimated_size / capacity * 100) if capacity > 0 else 100.0 usage_percent = (estimated_size / capacity * 100) if capacity > 0 else 100.0
return { return {
'fits': fits, "fits": fits,
'payload_size': payload_size, "payload_size": payload_size,
'capacity': capacity, "capacity": capacity,
'usage_percent': min(usage_percent, 100.0), "usage_percent": min(usage_percent, 100.0),
'headroom': capacity - estimated_size, "headroom": capacity - estimated_size,
'mode': EMBED_MODE_DCT, "mode": EMBED_MODE_DCT,
} }
else: else:
return will_fit(payload, carrier_image, bits_per_channel) return will_fit(payload, carrier_image, bits_per_channel)
@@ -359,17 +365,17 @@ def get_available_modes() -> dict:
""" """
return { return {
EMBED_MODE_LSB: { EMBED_MODE_LSB: {
'available': True, "available": True,
'name': 'Spatial LSB', "name": "Spatial LSB",
'description': 'Embed in pixel LSBs, outputs PNG/BMP', "description": "Embed in pixel LSBs, outputs PNG/BMP",
'output_format': 'PNG (color)', "output_format": "PNG (color)",
}, },
EMBED_MODE_DCT: { EMBED_MODE_DCT: {
'available': has_dct_support(), "available": has_dct_support(),
'name': 'DCT Domain', "name": "DCT Domain",
'description': 'Embed in DCT coefficients, outputs grayscale PNG or JPEG', "description": "Embed in DCT coefficients, outputs grayscale PNG or JPEG",
'output_formats': ['PNG (grayscale)', 'JPEG (grayscale)'], "output_formats": ["PNG (grayscale)", "JPEG (grayscale)"],
'requires': 'scipy', "requires": "scipy",
}, },
} }
@@ -403,20 +409,20 @@ def compare_modes(image_data: bytes) -> dict:
dct_available = False dct_available = False
return { return {
'width': width, "width": width,
'height': height, "height": height,
'lsb': { "lsb": {
'capacity_bytes': lsb_bytes, "capacity_bytes": lsb_bytes,
'capacity_kb': lsb_bytes / 1024, "capacity_kb": lsb_bytes / 1024,
'available': True, "available": True,
'output': 'PNG (color)', "output": "PNG (color)",
}, },
'dct': { "dct": {
'capacity_bytes': dct_bytes, "capacity_bytes": dct_bytes,
'capacity_kb': dct_bytes / 1024, "capacity_kb": dct_bytes / 1024,
'available': dct_available, "available": dct_available,
'output': 'PNG or JPEG (grayscale)', "output": "PNG or JPEG (grayscale)",
'ratio_vs_lsb': (dct_bytes / lsb_bytes * 100) if lsb_bytes > 0 else 0, "ratio_vs_lsb": (dct_bytes / lsb_bytes * 100) if lsb_bytes > 0 else 0,
}, },
} }
@@ -425,6 +431,7 @@ def compare_modes(image_data: bytes) -> dict:
# PIXEL INDEX GENERATION # PIXEL INDEX GENERATION
# ============================================================================= # =============================================================================
@debug.time @debug.time
def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list[int]: def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list[int]:
""" """
@@ -436,23 +443,24 @@ def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list
debug.validate(len(key) == 32, f"Pixel key must be 32 bytes, got {len(key)}") debug.validate(len(key) == 32, f"Pixel key must be 32 bytes, got {len(key)}")
debug.validate(num_pixels > 0, f"Number of pixels must be positive, got {num_pixels}") debug.validate(num_pixels > 0, f"Number of pixels must be positive, got {num_pixels}")
debug.validate(num_needed > 0, f"Number needed must be positive, got {num_needed}") debug.validate(num_needed > 0, f"Number needed must be positive, got {num_needed}")
debug.validate(num_needed <= num_pixels, debug.validate(
f"Cannot select {num_needed} pixels from {num_pixels} available") num_needed <= num_pixels, f"Cannot select {num_needed} pixels from {num_pixels} available"
)
debug.print(f"Generating {num_needed} pixel indices from {num_pixels} total pixels") debug.print(f"Generating {num_needed} pixel indices from {num_pixels} total pixels")
if num_needed >= num_pixels // 2: if num_needed >= num_pixels // 2:
debug.print(f"Using full shuffle (needed {num_needed}/{num_pixels} pixels)") debug.print(f"Using full shuffle (needed {num_needed}/{num_pixels} pixels)")
nonce = b'\x00' * 16 nonce = b"\x00" * 16
cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend()) cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend())
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
indices = list(range(num_pixels)) indices = list(range(num_pixels))
random_bytes = encryptor.update(b'\x00' * (num_pixels * 4)) random_bytes = encryptor.update(b"\x00" * (num_pixels * 4))
for i in range(num_pixels - 1, 0, -1): for i in range(num_pixels - 1, 0, -1):
j_bytes = random_bytes[(num_pixels - 1 - i) * 4 : (num_pixels - i) * 4] j_bytes = random_bytes[(num_pixels - 1 - i) * 4 : (num_pixels - i) * 4]
j = int.from_bytes(j_bytes, 'big') % (i + 1) j = int.from_bytes(j_bytes, "big") % (i + 1)
indices[i], indices[j] = indices[j], indices[i] indices[i], indices[j] = indices[j], indices[i]
selected = indices[:num_needed] selected = indices[:num_needed]
@@ -463,17 +471,17 @@ def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list
selected = [] selected = []
used = set() used = set()
nonce = b'\x00' * 16 nonce = b"\x00" * 16
cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend()) cipher = Cipher(algorithms.ChaCha20(key, nonce), mode=None, backend=default_backend())
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
bytes_needed = (num_needed * 2) * 4 bytes_needed = (num_needed * 2) * 4
random_bytes = encryptor.update(b'\x00' * bytes_needed) random_bytes = encryptor.update(b"\x00" * bytes_needed)
byte_offset = 0 byte_offset = 0
collisions = 0 collisions = 0
while len(selected) < num_needed and byte_offset < len(random_bytes) - 4: while len(selected) < num_needed and byte_offset < len(random_bytes) - 4:
idx = int.from_bytes(random_bytes[byte_offset:byte_offset + 4], 'big') % num_pixels idx = int.from_bytes(random_bytes[byte_offset : byte_offset + 4], "big") % num_pixels
byte_offset += 4 byte_offset += 4
if idx not in used: if idx not in used:
@@ -486,8 +494,8 @@ def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list
debug.print(f"Need {num_needed - len(selected)} more indices, generating...") debug.print(f"Need {num_needed - len(selected)} more indices, generating...")
extra_needed = num_needed - len(selected) extra_needed = num_needed - len(selected)
for _ in range(extra_needed * 2): for _ in range(extra_needed * 2):
extra_bytes = encryptor.update(b'\x00' * 4) extra_bytes = encryptor.update(b"\x00" * 4)
idx = int.from_bytes(extra_bytes, 'big') % num_pixels idx = int.from_bytes(extra_bytes, "big") % num_pixels
if idx not in used: if idx not in used:
used.add(idx) used.add(idx)
selected.append(idx) selected.append(idx)
@@ -495,8 +503,10 @@ def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list
break break
debug.print(f"Generated {len(selected)} indices with {collisions} collisions") debug.print(f"Generated {len(selected)} indices with {collisions} collisions")
debug.validate(len(selected) == num_needed, debug.validate(
f"Failed to generate enough indices: {len(selected)}/{num_needed}") len(selected) == num_needed,
f"Failed to generate enough indices: {len(selected)}/{num_needed}",
)
return selected return selected
@@ -504,6 +514,7 @@ def generate_pixel_indices(key: bytes, num_pixels: int, num_needed: int) -> list
# EMBEDDING FUNCTIONS # EMBEDDING FUNCTIONS
# ============================================================================= # =============================================================================
@debug.time @debug.time
def embed_in_image( def embed_in_image(
data: bytes, data: bytes,
@@ -513,8 +524,8 @@ def embed_in_image(
output_format: str | None = None, output_format: str | None = None,
embed_mode: str = EMBED_MODE_LSB, embed_mode: str = EMBED_MODE_LSB,
dct_output_format: str = DCT_OUTPUT_PNG, dct_output_format: str = DCT_OUTPUT_PNG,
dct_color_mode: str = 'grayscale', dct_color_mode: str = "grayscale",
) -> tuple[bytes, Union[EmbedStats, 'DCTEmbedStats'], str]: ) -> tuple[bytes, Union[EmbedStats, "DCTEmbedStats"], str]:
""" """
Embed data into an image using specified mode. Embed data into an image using specified mode.
@@ -537,15 +548,15 @@ def embed_in_image(
ImportError: If DCT mode requested but scipy unavailable ImportError: If DCT mode requested but scipy unavailable
""" """
debug.print(f"embed_in_image: mode={embed_mode}, data={len(data)} bytes") debug.print(f"embed_in_image: mode={embed_mode}, data={len(data)} bytes")
debug.validate(embed_mode in VALID_EMBED_MODES, debug.validate(
f"Invalid embed_mode: {embed_mode}. Use 'lsb' or 'dct'") embed_mode in VALID_EMBED_MODES, f"Invalid embed_mode: {embed_mode}. Use 'lsb' or 'dct'"
)
# DCT MODE # DCT MODE
if embed_mode == EMBED_MODE_DCT: if embed_mode == EMBED_MODE_DCT:
if not has_dct_support(): if not has_dct_support():
raise ImportError( raise ImportError(
"scipy is required for DCT embedding mode. " "scipy is required for DCT embedding mode. " "Install with: pip install scipy"
"Install with: pip install scipy"
) )
# Validate DCT output format # Validate DCT output format
@@ -554,9 +565,9 @@ def embed_in_image(
dct_output_format = DCT_OUTPUT_PNG dct_output_format = DCT_OUTPUT_PNG
# Validate DCT color mode (v3.0.1) # Validate DCT color mode (v3.0.1)
if dct_color_mode not in ('grayscale', 'color'): if dct_color_mode not in ("grayscale", "color"):
debug.print(f"Invalid dct_color_mode '{dct_color_mode}', defaulting to grayscale") debug.print(f"Invalid dct_color_mode '{dct_color_mode}', defaulting to grayscale")
dct_color_mode = 'grayscale' dct_color_mode = "grayscale"
dct_mod = _get_dct_module() dct_mod = _get_dct_module()
@@ -571,12 +582,14 @@ def embed_in_image(
# Determine extension based on output format # Determine extension based on output format
if dct_output_format == DCT_OUTPUT_JPEG: if dct_output_format == DCT_OUTPUT_JPEG:
ext = 'jpg' ext = "jpg"
else: else:
ext = 'png' ext = "png"
debug.print(f"DCT embedding complete: {dct_output_format.upper()} output, " debug.print(
f"color_mode={dct_color_mode}, ext={ext}") f"DCT embedding complete: {dct_output_format.upper()} output, "
f"color_mode={dct_color_mode}, ext={ext}"
)
return stego_bytes, dct_stats, ext return stego_bytes, dct_stats, ext
# LSB MODE # LSB MODE
@@ -595,10 +608,10 @@ def _embed_lsb(
""" """
debug.print(f"LSB embedding {len(data)} bytes into image") debug.print(f"LSB embedding {len(data)} bytes into image")
debug.data(pixel_key, "Pixel key for embedding") debug.data(pixel_key, "Pixel key for embedding")
debug.validate(bits_per_channel in (1, 2), debug.validate(
f"bits_per_channel must be 1 or 2, got {bits_per_channel}") bits_per_channel in (1, 2), f"bits_per_channel must be 1 or 2, got {bits_per_channel}"
debug.validate(len(pixel_key) == 32, )
f"Pixel key must be 32 bytes, got {len(pixel_key)}") debug.validate(len(pixel_key) == 32, f"Pixel key must be 32 bytes, got {len(pixel_key)}")
img_file = None img_file = None
img = None img = None
@@ -610,8 +623,8 @@ def _embed_lsb(
debug.print(f"Carrier image: {img_file.size[0]}x{img_file.size[1]}, format: {input_format}") debug.print(f"Carrier image: {img_file.size[0]}x{img_file.size[1]}, format: {input_format}")
img = img_file.convert('RGB') if img_file.mode != 'RGB' else img_file.copy() img = img_file.convert("RGB") if img_file.mode != "RGB" else img_file.copy()
if img_file.mode != 'RGB': if img_file.mode != "RGB":
debug.print(f"Converting image from {img_file.mode} to RGB") debug.print(f"Converting image from {img_file.mode} to RGB")
pixels = list(img.getdata()) pixels = list(img.getdata())
@@ -622,16 +635,18 @@ def _embed_lsb(
debug.print(f"Image capacity: {max_bytes} bytes at {bits_per_channel} bit(s)/channel") debug.print(f"Image capacity: {max_bytes} bytes at {bits_per_channel} bit(s)/channel")
data_with_len = struct.pack('>I', len(data)) + data data_with_len = struct.pack(">I", len(data)) + data
if len(data_with_len) > max_bytes: if len(data_with_len) > max_bytes:
debug.print(f"Capacity error: need {len(data_with_len)}, have {max_bytes}") debug.print(f"Capacity error: need {len(data_with_len)}, have {max_bytes}")
raise CapacityError(len(data_with_len), max_bytes) raise CapacityError(len(data_with_len), max_bytes)
debug.print(f"Total data to embed: {len(data_with_len)} bytes " debug.print(
f"({len(data_with_len)/max_bytes*100:.1f}% of capacity)") f"Total data to embed: {len(data_with_len)} bytes "
f"({len(data_with_len)/max_bytes*100:.1f}% of capacity)"
)
binary_data = ''.join(format(b, '08b') for b in data_with_len) binary_data = "".join(format(b, "08b") for b in data_with_len)
pixels_needed = (len(binary_data) + bits_per_pixel - 1) // bits_per_pixel pixels_needed = (len(binary_data) + bits_per_pixel - 1) // bits_per_pixel
debug.print(f"Need {pixels_needed} pixels to embed {len(binary_data)} bits") debug.print(f"Need {pixels_needed} pixels to embed {len(binary_data)} bits")
@@ -654,7 +669,9 @@ def _embed_lsb(
for channel_idx, channel_val in enumerate([r, g, b]): for channel_idx, channel_val in enumerate([r, g, b]):
if bit_idx >= len(binary_data): if bit_idx >= len(binary_data):
break break
bits = binary_data[bit_idx:bit_idx + bits_per_channel].ljust(bits_per_channel, '0') bits = binary_data[bit_idx : bit_idx + bits_per_channel].ljust(
bits_per_channel, "0"
)
new_val = (channel_val & clear_mask) | int(bits, 2) new_val = (channel_val & clear_mask) | int(bits, 2)
if channel_val != new_val: if channel_val != new_val:
@@ -674,12 +691,12 @@ def _embed_lsb(
debug.print(f"Modified {modified_pixels} pixels (out of {len(selected_indices)} selected)") debug.print(f"Modified {modified_pixels} pixels (out of {len(selected_indices)} selected)")
stego_img = Image.new('RGB', img.size) stego_img = Image.new("RGB", img.size)
stego_img.putdata(new_pixels) stego_img.putdata(new_pixels)
if output_format: if output_format:
out_fmt = output_format.upper() out_fmt = output_format.upper()
out_ext = FORMAT_TO_EXT.get(out_fmt, 'png') out_ext = FORMAT_TO_EXT.get(out_fmt, "png")
debug.print(f"Using forced output format: {out_fmt}") debug.print(f"Using forced output format: {out_fmt}")
else: else:
out_fmt, out_ext = get_output_format(input_format) out_fmt, out_ext = get_output_format(input_format)
@@ -693,7 +710,7 @@ def _embed_lsb(
pixels_modified=modified_pixels, pixels_modified=modified_pixels,
total_pixels=num_pixels, total_pixels=num_pixels,
capacity_used=len(data_with_len) / max_bytes, capacity_used=len(data_with_len) / max_bytes,
bytes_embedded=len(data_with_len) bytes_embedded=len(data_with_len),
) )
debug.print(f"LSB embedding complete: {out_fmt} image, {len(output.getvalue())} bytes") debug.print(f"LSB embedding complete: {out_fmt} image, {len(output.getvalue())} bytes")
@@ -718,6 +735,7 @@ def _embed_lsb(
# EXTRACTION FUNCTIONS # EXTRACTION FUNCTIONS
# ============================================================================= # =============================================================================
@debug.time @debug.time
def extract_from_image( def extract_from_image(
image_data: bytes, image_data: bytes,
@@ -777,18 +795,15 @@ def _extract_dct(image_data: bytes, pixel_key: bytes) -> bytes | None:
return None return None
def _extract_lsb( def _extract_lsb(image_data: bytes, pixel_key: bytes, bits_per_channel: int = 1) -> bytes | None:
image_data: bytes,
pixel_key: bytes,
bits_per_channel: int = 1
) -> bytes | None:
""" """
Extract using LSB mode (internal implementation). Extract using LSB mode (internal implementation).
""" """
debug.print(f"LSB extracting from {len(image_data)} byte image") debug.print(f"LSB extracting from {len(image_data)} byte image")
debug.data(pixel_key, "Pixel key for extraction") debug.data(pixel_key, "Pixel key for extraction")
debug.validate(bits_per_channel in (1, 2), debug.validate(
f"bits_per_channel must be 1 or 2, got {bits_per_channel}") bits_per_channel in (1, 2), f"bits_per_channel must be 1 or 2, got {bits_per_channel}"
)
img_file = None img_file = None
img = None img = None
@@ -797,8 +812,8 @@ def _extract_lsb(
img_file = Image.open(io.BytesIO(image_data)) img_file = Image.open(io.BytesIO(image_data))
debug.print(f"Image: {img_file.size[0]}x{img_file.size[1]}, format: {img_file.format}") debug.print(f"Image: {img_file.size[0]}x{img_file.size[1]}, format: {img_file.format}")
img = img_file.convert('RGB') if img_file.mode != 'RGB' else img_file.copy() img = img_file.convert("RGB") if img_file.mode != "RGB" else img_file.copy()
if img_file.mode != 'RGB': if img_file.mode != "RGB":
debug.print(f"Converting image from {img_file.mode} to RGB") debug.print(f"Converting image from {img_file.mode} to RGB")
pixels = list(img.getdata()) pixels = list(img.getdata())
@@ -812,7 +827,7 @@ def _extract_lsb(
initial_indices = generate_pixel_indices(pixel_key, num_pixels, initial_pixels) initial_indices = generate_pixel_indices(pixel_key, num_pixels, initial_pixels)
binary_data = '' binary_data = ""
for pixel_idx in initial_indices: for pixel_idx in initial_indices:
r, g, b = pixels[pixel_idx] r, g, b = pixels[pixel_idx]
for channel in [r, g, b]: for channel in [r, g, b]:
@@ -825,7 +840,7 @@ def _extract_lsb(
debug.print(f"Not enough bits for length: {len(length_bits)}/32") debug.print(f"Not enough bits for length: {len(length_bits)}/32")
return None return None
data_length = struct.unpack('>I', int(length_bits, 2).to_bytes(4, 'big'))[0] data_length = struct.unpack(">I", int(length_bits, 2).to_bytes(4, "big"))[0]
debug.print(f"Extracted length: {data_length} bytes") debug.print(f"Extracted length: {data_length} bytes")
except Exception as e: except Exception as e:
debug.print(f"Failed to parse length: {e}") debug.print(f"Failed to parse length: {e}")
@@ -843,7 +858,7 @@ def _extract_lsb(
selected_indices = generate_pixel_indices(pixel_key, num_pixels, pixels_needed) selected_indices = generate_pixel_indices(pixel_key, num_pixels, pixels_needed)
binary_data = '' binary_data = ""
for pixel_idx in selected_indices: for pixel_idx in selected_indices:
r, g, b = pixels[pixel_idx] r, g, b = pixels[pixel_idx]
for channel in [r, g, b]: for channel in [r, g, b]:
@@ -880,6 +895,7 @@ def _extract_lsb(
# UTILITY FUNCTIONS # UTILITY FUNCTIONS
# ============================================================================= # =============================================================================
def get_image_dimensions(image_data: bytes) -> tuple[int, int]: def get_image_dimensions(image_data: bytes) -> tuple[int, int]:
"""Get image dimensions without loading full image.""" """Get image dimensions without loading full image."""
debug.validate(len(image_data) > 0, "Image data cannot be empty") debug.validate(len(image_data) > 0, "Image data cannot be empty")

View File

@@ -18,7 +18,7 @@ from .constants import DAY_NAMES
from .debug import debug from .debug import debug
def strip_image_metadata(image_data: bytes, output_format: str = 'PNG') -> bytes: def strip_image_metadata(image_data: bytes, output_format: str = "PNG") -> bytes:
""" """
Remove all metadata (EXIF, ICC profiles, etc.) from an image. Remove all metadata (EXIF, ICC profiles, etc.) from an image.
@@ -41,8 +41,8 @@ def strip_image_metadata(image_data: bytes, output_format: str = 'PNG') -> bytes
img = Image.open(io.BytesIO(image_data)) img = Image.open(io.BytesIO(image_data))
# Convert to RGB if needed (handles RGBA, P, L, etc.) # Convert to RGB if needed (handles RGBA, P, L, etc.)
if img.mode not in ('RGB', 'RGBA'): if img.mode not in ("RGB", "RGBA"):
img = img.convert('RGB') img = img.convert("RGB")
# Create fresh image - this discards all metadata # Create fresh image - this discards all metadata
clean = Image.new(img.mode, img.size) clean = Image.new(img.mode, img.size)
@@ -56,11 +56,7 @@ def strip_image_metadata(image_data: bytes, output_format: str = 'PNG') -> bytes
return output.getvalue() return output.getvalue()
def generate_filename( def generate_filename(date_str: str | None = None, prefix: str = "", extension: str = "png") -> str:
date_str: str | None = None,
prefix: str = "",
extension: str = "png"
) -> str:
""" """
Generate a filename for stego images. Generate a filename for stego images.
@@ -78,17 +74,19 @@ def generate_filename(
>>> generate_filename("2023-12-25", "secret_", "png") >>> generate_filename("2023-12-25", "secret_", "png")
"secret_a1b2c3d4_20231225.png" "secret_a1b2c3d4_20231225.png"
""" """
debug.validate(bool(extension) and '.' not in extension, debug.validate(
f"Extension must not contain dot, got '{extension}'") bool(extension) and "." not in extension,
f"Extension must not contain dot, got '{extension}'",
)
if date_str is None: if date_str is None:
date_str = date.today().isoformat() date_str = date.today().isoformat()
date_compact = date_str.replace('-', '') date_compact = date_str.replace("-", "")
random_hex = secrets.token_hex(4) random_hex = secrets.token_hex(4)
# Ensure extension doesn't have a leading dot # Ensure extension doesn't have a leading dot
extension = extension.lstrip('.') extension = extension.lstrip(".")
filename = f"{prefix}{random_hex}_{date_compact}.{extension}" filename = f"{prefix}{random_hex}_{date_compact}.{extension}"
debug.print(f"Generated filename: {filename}") debug.print(f"Generated filename: {filename}")
@@ -114,7 +112,7 @@ def parse_date_from_filename(filename: str) -> str | None:
import re import re
# Try YYYYMMDD format # Try YYYYMMDD format
match = re.search(r'_(\d{4})(\d{2})(\d{2})(?:\.|$)', filename) match = re.search(r"_(\d{4})(\d{2})(\d{2})(?:\.|$)", filename)
if match: if match:
year, month, day = match.groups() year, month, day = match.groups()
date_str = f"{year}-{month}-{day}" date_str = f"{year}-{month}-{day}"
@@ -122,7 +120,7 @@ def parse_date_from_filename(filename: str) -> str | None:
return date_str return date_str
# Try YYYY-MM-DD format # Try YYYY-MM-DD format
match = re.search(r'_(\d{4})-(\d{2})-(\d{2})(?:\.|$)', filename) match = re.search(r"_(\d{4})-(\d{2})-(\d{2})(?:\.|$)", filename)
if match: if match:
year, month, day = match.groups() year, month, day = match.groups()
date_str = f"{year}-{month}-{day}" date_str = f"{year}-{month}-{day}"
@@ -147,11 +145,13 @@ def get_day_from_date(date_str: str) -> str:
>>> get_day_from_date("2023-12-25") >>> get_day_from_date("2023-12-25")
"Monday" "Monday"
""" """
debug.validate(len(date_str) == 10 and date_str[4] == '-' and date_str[7] == '-', debug.validate(
f"Invalid date format: {date_str}, expected YYYY-MM-DD") len(date_str) == 10 and date_str[4] == "-" and date_str[7] == "-",
f"Invalid date format: {date_str}, expected YYYY-MM-DD",
)
try: try:
year, month, day = map(int, date_str.split('-')) year, month, day = map(int, date_str.split("-"))
d = date(year, month, day) d = date(year, month, day)
day_name = DAY_NAMES[d.weekday()] day_name = DAY_NAMES[d.weekday()]
debug.print(f"Date {date_str} is {day_name}") debug.print(f"Date {date_str} is {day_name}")
@@ -231,11 +231,11 @@ class SecureDeleter:
debug.print("File is empty, nothing to overwrite") debug.print("File is empty, nothing to overwrite")
return return
patterns = [b'\x00', b'\xFF', bytes([random.randint(0, 255)])] patterns = [b"\x00", b"\xff", bytes([random.randint(0, 255)])]
for pass_num in range(self.passes): for pass_num in range(self.passes):
debug.print(f"Overwrite pass {pass_num + 1}/{self.passes}") debug.print(f"Overwrite pass {pass_num + 1}/{self.passes}")
with open(file_path, 'r+b') as f: with open(file_path, "r+b") as f:
for pattern_idx, pattern in enumerate(patterns): for pattern_idx, pattern in enumerate(patterns):
f.seek(0) f.seek(0)
# Write pattern in chunks for large files # Write pattern in chunks for large files
@@ -271,7 +271,7 @@ class SecureDeleter:
# First, securely overwrite all files # First, securely overwrite all files
file_count = 0 file_count = 0
for file_path in self.path.rglob('*'): for file_path in self.path.rglob("*"):
if file_path.is_file(): if file_path.is_file():
self._overwrite_file(file_path) self._overwrite_file(file_path)
file_count += 1 file_count += 1
@@ -325,9 +325,9 @@ def format_file_size(size_bytes: int) -> str:
debug.validate(size_bytes >= 0, f"File size cannot be negative: {size_bytes}") debug.validate(size_bytes >= 0, f"File size cannot be negative: {size_bytes}")
size: float = float(size_bytes) size: float = float(size_bytes)
for unit in ['B', 'KB', 'MB', 'GB']: for unit in ["B", "KB", "MB", "GB"]:
if size < 1024: if size < 1024:
if unit == 'B': if unit == "B":
return f"{int(size)} {unit}" return f"{int(size)} {unit}"
return f"{size:.1f} {unit}" return f"{size:.1f} {unit}"
size /= 1024 size /= 1024

View File

@@ -66,11 +66,9 @@ def validate_pin(pin: str, required: bool = False) -> ValidationResult:
return ValidationResult.error("PIN must contain only digits") return ValidationResult.error("PIN must contain only digits")
if len(pin) < MIN_PIN_LENGTH or len(pin) > MAX_PIN_LENGTH: if len(pin) < MIN_PIN_LENGTH or len(pin) > MAX_PIN_LENGTH:
return ValidationResult.error( return ValidationResult.error(f"PIN must be {MIN_PIN_LENGTH}-{MAX_PIN_LENGTH} digits")
f"PIN must be {MIN_PIN_LENGTH}-{MAX_PIN_LENGTH} digits"
)
if pin[0] == '0': if pin[0] == "0":
return ValidationResult.error("PIN cannot start with zero") return ValidationResult.error("PIN cannot start with zero")
return ValidationResult.ok(length=len(pin)) return ValidationResult.ok(length=len(pin))
@@ -121,9 +119,7 @@ def validate_payload(payload: str | bytes | FilePayload) -> ValidationResult:
) )
return ValidationResult.ok( return ValidationResult.ok(
size=len(payload.data), size=len(payload.data), filename=payload.filename, mime_type=payload.mime_type
filename=payload.filename,
mime_type=payload.mime_type
) )
elif isinstance(payload, bytes): elif isinstance(payload, bytes):
@@ -143,9 +139,7 @@ def validate_payload(payload: str | bytes | FilePayload) -> ValidationResult:
def validate_file_payload( def validate_file_payload(
file_data: bytes, file_data: bytes, filename: str = "", max_size: int = MAX_FILE_PAYLOAD_SIZE
filename: str = "",
max_size: int = MAX_FILE_PAYLOAD_SIZE
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate a file for embedding. Validate a file for embedding.
@@ -173,9 +167,7 @@ def validate_file_payload(
def validate_image( def validate_image(
image_data: bytes, image_data: bytes, name: str = "Image", check_size: bool = True
name: str = "Image",
check_size: bool = True
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate image data and dimensions. Validate image data and dimensions.
@@ -209,11 +201,7 @@ def validate_image(
) )
return ValidationResult.ok( return ValidationResult.ok(
width=width, width=width, height=height, pixels=num_pixels, mode=img.mode, format=img.format
height=height,
pixels=num_pixels,
mode=img.mode,
format=img.format
) )
except Exception as e: except Exception as e:
@@ -221,9 +209,7 @@ def validate_image(
def validate_rsa_key( def validate_rsa_key(
key_data: bytes, key_data: bytes, password: str | None = None, required: bool = False
password: str | None = None,
required: bool = False
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate RSA private key. Validate RSA private key.
@@ -256,10 +242,7 @@ def validate_rsa_key(
return ValidationResult.error(str(e)) return ValidationResult.error(str(e))
def validate_security_factors( def validate_security_factors(pin: str, rsa_key_data: bytes | None) -> ValidationResult:
pin: str,
rsa_key_data: bytes | None
) -> ValidationResult:
""" """
Validate that at least one security factor is provided. Validate that at least one security factor is provided.
@@ -274,17 +257,13 @@ def validate_security_factors(
has_key = bool(rsa_key_data and len(rsa_key_data) > 0) has_key = bool(rsa_key_data and len(rsa_key_data) > 0)
if not has_pin and not has_key: if not has_pin and not has_key:
return ValidationResult.error( return ValidationResult.error("You must provide at least a PIN or RSA Key")
"You must provide at least a PIN or RSA Key"
)
return ValidationResult.ok(has_pin=has_pin, has_key=has_key) return ValidationResult.ok(has_pin=has_pin, has_key=has_key)
def validate_file_extension( def validate_file_extension(
filename: str, filename: str, allowed: set[str], file_type: str = "File"
allowed: set[str],
file_type: str = "File"
) -> ValidationResult: ) -> ValidationResult:
""" """
Validate file extension. Validate file extension.
@@ -297,10 +276,10 @@ def validate_file_extension(
Returns: Returns:
ValidationResult with extension ValidationResult with extension
""" """
if not filename or '.' not in filename: if not filename or "." not in filename:
return ValidationResult.error(f"{file_type} must have a file extension") return ValidationResult.error(f"{file_type} must have a file extension")
ext = filename.rsplit('.', 1)[1].lower() ext = filename.rsplit(".", 1)[1].lower()
if ext not in allowed: if ext not in allowed:
return ValidationResult.error( return ValidationResult.error(
@@ -368,7 +347,7 @@ def validate_passphrase(passphrase: str) -> ValidationResult:
if len(words) < RECOMMENDED_PASSPHRASE_WORDS: if len(words) < RECOMMENDED_PASSPHRASE_WORDS:
return ValidationResult.ok( return ValidationResult.ok(
word_count=len(words), word_count=len(words),
warning=f"Recommend {RECOMMENDED_PASSPHRASE_WORDS}+ words for better security" warning=f"Recommend {RECOMMENDED_PASSPHRASE_WORDS}+ words for better security",
) )
return ValidationResult.ok(word_count=len(words)) return ValidationResult.ok(word_count=len(words))
@@ -378,6 +357,7 @@ def validate_passphrase(passphrase: str) -> ValidationResult:
# NEW VALIDATORS FOR V3.2.0 # NEW VALIDATORS FOR V3.2.0
# ============================================================================= # =============================================================================
def validate_reference_photo(photo_data: bytes) -> ValidationResult: def validate_reference_photo(photo_data: bytes) -> ValidationResult:
"""Validate reference photo. Alias for validate_image.""" """Validate reference photo. Alias for validate_image."""
return validate_image(photo_data, "Reference photo") return validate_image(photo_data, "Reference photo")
@@ -418,7 +398,7 @@ def validate_dct_output_format(format_str: str) -> ValidationResult:
Returns: Returns:
ValidationResult ValidationResult
""" """
valid_formats = {'png', 'jpeg'} valid_formats = {"png", "jpeg"}
if format_str.lower() not in valid_formats: if format_str.lower() not in valid_formats:
return ValidationResult.error( return ValidationResult.error(
@@ -438,7 +418,7 @@ def validate_dct_color_mode(mode: str) -> ValidationResult:
Returns: Returns:
ValidationResult ValidationResult
""" """
valid_modes = {'grayscale', 'color'} valid_modes = {"grayscale", "color"}
if mode.lower() not in valid_modes: if mode.lower() not in valid_modes:
return ValidationResult.error( return ValidationResult.error(
@@ -452,6 +432,7 @@ def validate_dct_color_mode(mode: str) -> ValidationResult:
# EXCEPTION-RAISING VALIDATORS (for CLI/API use) # EXCEPTION-RAISING VALIDATORS (for CLI/API use)
# ============================================================================ # ============================================================================
def require_valid_pin(pin: str, required: bool = False) -> None: def require_valid_pin(pin: str, required: bool = False) -> None:
"""Validate PIN, raising exception on failure.""" """Validate PIN, raising exception on failure."""
result = validate_pin(pin, required) result = validate_pin(pin, required)
@@ -481,9 +462,7 @@ def require_valid_image(image_data: bytes, name: str = "Image") -> None:
def require_valid_rsa_key( def require_valid_rsa_key(
key_data: bytes, key_data: bytes, password: str | None = None, required: bool = False
password: str | None = None,
required: bool = False
) -> None: ) -> None:
"""Validate RSA key, raising exception on failure.""" """Validate RSA key, raising exception on failure."""
result = validate_rsa_key(key_data, password, required) result = validate_rsa_key(key_data, password, required)

View File

@@ -41,8 +41,8 @@ def sample_images(temp_dir):
images = [] images = []
for i in range(3): for i in range(3):
img_path = temp_dir / f"test_image_{i}.png" img_path = temp_dir / f"test_image_{i}.png"
img = Image.new('RGB', (100, 100), color=(i * 50, i * 50, i * 50)) img = Image.new("RGB", (100, 100), color=(i * 50, i * 50, i * 50))
img.save(img_path, 'PNG') img.save(img_path, "PNG")
images.append(img_path) images.append(img_path)
return images return images
@@ -55,9 +55,9 @@ def sample_reference_photo():
from PIL import Image from PIL import Image
img = Image.new('RGB', (100, 100), color=(128, 128, 128)) img = Image.new("RGB", (100, 100), color=(128, 128, 128))
buf = BytesIO() buf = BytesIO()
img.save(buf, 'PNG') img.save(buf, "PNG")
return buf.getvalue() return buf.getvalue()
@@ -67,7 +67,7 @@ def sample_credentials(sample_reference_photo):
return { return {
"reference_photo": sample_reference_photo, "reference_photo": sample_reference_photo,
"passphrase": "test phrase four words", # v3.2.0: single passphrase "passphrase": "test phrase four words", # v3.2.0: single passphrase
"pin": "123456" "pin": "123456",
} }
@@ -95,9 +95,9 @@ class TestBatchItem:
message="Done", message="Done",
) )
result = item.to_dict() result = item.to_dict()
assert result['input_path'] == "input.png" assert result["input_path"] == "input.png"
assert result['output_path'] == "output.png" assert result["output_path"] == "output.png"
assert result['status'] == "success" assert result["status"] == "success"
class TestBatchResult: class TestBatchResult:
@@ -106,11 +106,12 @@ class TestBatchResult:
def test_to_json(self): def test_to_json(self):
"""Should serialize to valid JSON.""" """Should serialize to valid JSON."""
import json import json
result = BatchResult(operation="encode", total=5, succeeded=4, failed=1) result = BatchResult(operation="encode", total=5, succeeded=4, failed=1)
json_str = result.to_json() json_str = result.to_json()
parsed = json.loads(json_str) parsed = json.loads(json_str)
assert parsed['operation'] == "encode" assert parsed["operation"] == "encode"
assert parsed['summary']['total'] == 5 assert parsed["summary"]["total"] == 5
def test_duration_with_end_time(self): def test_duration_with_end_time(self):
"""Duration should work when end_time is set.""" """Duration should work when end_time is set."""
@@ -128,7 +129,7 @@ class TestBatchCredentials:
data = { data = {
"reference_photo": sample_reference_photo, "reference_photo": sample_reference_photo,
"passphrase": "test phrase four words", "passphrase": "test phrase four words",
"pin": "123456" "pin": "123456",
} }
creds = BatchCredentials.from_dict(data) creds = BatchCredentials.from_dict(data)
assert creds.passphrase == "test phrase four words" assert creds.passphrase == "test phrase four words"
@@ -139,7 +140,7 @@ class TestBatchCredentials:
data = { data = {
"reference_photo": sample_reference_photo, "reference_photo": sample_reference_photo,
"day_phrase": "legacy phrase here", # Old key name "day_phrase": "legacy phrase here", # Old key name
"pin": "123456" "pin": "123456",
} }
creds = BatchCredentials.from_dict(data) creds = BatchCredentials.from_dict(data)
# Should accept old key and map to passphrase # Should accept old key and map to passphrase
@@ -151,19 +152,19 @@ class TestBatchCredentials:
creds = BatchCredentials( creds = BatchCredentials(
reference_photo=sample_reference_photo, reference_photo=sample_reference_photo,
passphrase="test phrase four words", passphrase="test phrase four words",
pin="123456" pin="123456",
) )
result = creds.to_dict() result = creds.to_dict()
assert result['passphrase'] == "test phrase four words" assert result["passphrase"] == "test phrase four words"
assert result['pin'] == "123456" assert result["pin"] == "123456"
assert 'day_phrase' not in result # Old key should not be present assert "day_phrase" not in result # Old key should not be present
def test_passphrase_is_string(self, sample_reference_photo): def test_passphrase_is_string(self, sample_reference_photo):
"""Passphrase should be a string, not a dict.""" """Passphrase should be a string, not a dict."""
creds = BatchCredentials( creds = BatchCredentials(
reference_photo=sample_reference_photo, reference_photo=sample_reference_photo,
passphrase="test phrase four words", passphrase="test phrase four words",
pin="123456" pin="123456",
) )
assert isinstance(creds.passphrase, str) assert isinstance(creds.passphrase, str)
@@ -216,7 +217,7 @@ class TestBatchProcessor:
nested = temp_dir / "nested" nested = temp_dir / "nested"
nested.mkdir() nested.mkdir()
img_path = nested / "nested.png" img_path = nested / "nested.png"
img = Image.new('RGB', (50, 50)) img = Image.new("RGB", (50, 50))
img.save(img_path) img.save(img_path)
processor = BatchProcessor() processor = BatchProcessor()
@@ -241,7 +242,9 @@ class TestBatchProcessor:
message="test", message="test",
) )
def test_batch_encode_accepts_passphrase_credentials(self, sample_images, temp_dir, sample_credentials): def test_batch_encode_accepts_passphrase_credentials(
self, sample_images, temp_dir, sample_credentials
):
"""Should accept v3.2.0 format credentials with passphrase.""" """Should accept v3.2.0 format credentials with passphrase."""
processor = BatchProcessor() processor = BatchProcessor()
result = processor.batch_encode( result = processor.batch_encode(
@@ -343,9 +346,9 @@ class TestBatchCapacityCheck:
"""Results should include capacity info.""" """Results should include capacity info."""
results = batch_capacity_check(sample_images) results = batch_capacity_check(sample_images)
for item in results: for item in results:
assert 'capacity_bytes' in item assert "capacity_bytes" in item
assert 'dimensions' in item assert "dimensions" in item
assert 'valid' in item assert "valid" in item
def test_handles_invalid_files(self, temp_dir): def test_handles_invalid_files(self, temp_dir):
"""Should handle non-image files gracefully.""" """Should handle non-image files gracefully."""
@@ -354,7 +357,7 @@ class TestBatchCapacityCheck:
results = batch_capacity_check([bad_file]) results = batch_capacity_check([bad_file])
assert len(results) == 1 assert len(results) == 1
assert 'error' in results[0] assert "error" in results[0]
class TestPrintBatchResult: class TestPrintBatchResult:
@@ -403,7 +406,7 @@ class TestCredentialsMigration:
old_format = { old_format = {
"reference_photo": sample_reference_photo, "reference_photo": sample_reference_photo,
"phrase": "old style phrase", "phrase": "old style phrase",
"pin": "123456" "pin": "123456",
} }
# Should not raise # Should not raise
creds = BatchCredentials.from_dict(old_format) creds = BatchCredentials.from_dict(old_format)
@@ -414,7 +417,7 @@ class TestCredentialsMigration:
old_format = { old_format = {
"reference_photo": sample_reference_photo, "reference_photo": sample_reference_photo,
"day_phrase": "old day phrase", "day_phrase": "old day phrase",
"pin": "123456" "pin": "123456",
} }
creds = BatchCredentials.from_dict(old_format) creds = BatchCredentials.from_dict(old_format)
assert creds.passphrase == "old day phrase" assert creds.passphrase == "old day phrase"
@@ -425,7 +428,7 @@ class TestCredentialsMigration:
"reference_photo": sample_reference_photo, "reference_photo": sample_reference_photo,
"passphrase": "new style passphrase", "passphrase": "new style passphrase",
"day_phrase": "old day phrase", "day_phrase": "old day phrase",
"pin": "123456" "pin": "123456",
} }
creds = BatchCredentials.from_dict(mixed_format) creds = BatchCredentials.from_dict(mixed_format)
assert creds.passphrase == "new style passphrase" assert creds.passphrase == "new style passphrase"

View File

@@ -42,6 +42,7 @@ class TestCompress:
def test_compress_incompressible_data(self): def test_compress_incompressible_data(self):
"""Incompressible data should be stored uncompressed.""" """Incompressible data should be stored uncompressed."""
import os import os
# Random data doesn't compress well # Random data doesn't compress well
data = os.urandom(500) data = os.urandom(500)
result = compress(data, CompressionAlgorithm.ZLIB) result = compress(data, CompressionAlgorithm.ZLIB)
@@ -107,6 +108,7 @@ class TestDecompress:
def test_roundtrip_large_data(self): def test_roundtrip_large_data(self):
"""Large data should survive compress/decompress roundtrip.""" """Large data should survive compress/decompress roundtrip."""
import os import os
original = os.urandom(50000) original = os.urandom(50000)
compressed = compress(original) compressed = compress(original)
result = decompress(compressed) result = decompress(compressed)
@@ -173,7 +175,7 @@ class TestEdgeCases:
def test_unicode_after_encoding(self): def test_unicode_after_encoding(self):
"""UTF-8 encoded Unicode should compress correctly.""" """UTF-8 encoded Unicode should compress correctly."""
text = "Hello, 世界! 🎉 " * 100 text = "Hello, 世界! 🎉 " * 100
data = text.encode('utf-8') data = text.encode("utf-8")
compressed = compress(data) compressed = compress(data)
result = decompress(compressed) result = decompress(compressed)
assert result.decode('utf-8') == text assert result.decode("utf-8") == text

View File

@@ -37,12 +37,13 @@ from stegasoo.steganography import get_output_format
# Fixtures # Fixtures
# ============================================================================= # =============================================================================
@pytest.fixture @pytest.fixture
def png_image(): def png_image():
"""Create a test PNG image.""" """Create a test PNG image."""
img = Image.new('RGB', (100, 100), color='red') img = Image.new("RGB", (100, 100), color="red")
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format='PNG') img.save(buf, format="PNG")
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
@@ -50,9 +51,9 @@ def png_image():
@pytest.fixture @pytest.fixture
def large_png_image(): def large_png_image():
"""Create a larger test PNG image for DCT mode.""" """Create a larger test PNG image for DCT mode."""
img = Image.new('RGB', (400, 400), color='blue') img = Image.new("RGB", (400, 400), color="blue")
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format='PNG') img.save(buf, format="PNG")
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
@@ -60,9 +61,9 @@ def large_png_image():
@pytest.fixture @pytest.fixture
def bmp_image(): def bmp_image():
"""Create a test BMP image.""" """Create a test BMP image."""
img = Image.new('RGB', (100, 100), color='blue') img = Image.new("RGB", (100, 100), color="blue")
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format='BMP') img.save(buf, format="BMP")
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
@@ -70,9 +71,9 @@ def bmp_image():
@pytest.fixture @pytest.fixture
def jpeg_image(): def jpeg_image():
"""Create a test JPEG image.""" """Create a test JPEG image."""
img = Image.new('RGB', (100, 100), color='green') img = Image.new("RGB", (100, 100), color="green")
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format='JPEG') img.save(buf, format="JPEG")
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
@@ -80,9 +81,9 @@ def jpeg_image():
@pytest.fixture @pytest.fixture
def gif_image(): def gif_image():
"""Create a test GIF image.""" """Create a test GIF image."""
img = Image.new('RGB', (100, 100), color='yellow') img = Image.new("RGB", (100, 100), color="yellow")
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format='GIF') img.save(buf, format="GIF")
buf.seek(0) buf.seek(0)
return buf.getvalue() return buf.getvalue()
@@ -91,6 +92,7 @@ def gif_image():
# Key Generation Tests (v3.2.0 Updated) # Key Generation Tests (v3.2.0 Updated)
# ============================================================================= # =============================================================================
class TestKeygen: class TestKeygen:
"""Tests for key generation functions.""" """Tests for key generation functions."""
@@ -99,7 +101,7 @@ class TestKeygen:
pin = generate_pin() pin = generate_pin()
assert len(pin) == 6 assert len(pin) == 6
assert pin.isdigit() assert pin.isdigit()
assert pin[0] != '0' assert pin[0] != "0"
def test_generate_pin_lengths(self): def test_generate_pin_lengths(self):
"""PIN generation should work for all valid lengths.""" """PIN generation should work for all valid lengths."""
@@ -129,7 +131,7 @@ class TestKeygen:
# v3.2.0: Single passphrase instead of 7 daily phrases # v3.2.0: Single passphrase instead of 7 daily phrases
assert creds.passphrase is not None assert creds.passphrase is not None
assert isinstance(creds.passphrase, str) assert isinstance(creds.passphrase, str)
assert ' ' in creds.passphrase # Should have multiple words assert " " in creds.passphrase # Should have multiple words
def test_generate_credentials_rsa_only(self): def test_generate_credentials_rsa_only(self):
"""RSA-only credentials should have single passphrase.""" """RSA-only credentials should have single passphrase."""
@@ -180,6 +182,7 @@ class TestKeygen:
# Validation Tests (v3.2.0 Updated) # Validation Tests (v3.2.0 Updated)
# ============================================================================= # =============================================================================
class TestValidation: class TestValidation:
"""Tests for validation functions.""" """Tests for validation functions."""
@@ -250,56 +253,59 @@ class TestValidation:
# Output Format Tests # Output Format Tests
# ============================================================================= # =============================================================================
class TestOutputFormat: class TestOutputFormat:
"""Tests for output format handling.""" """Tests for output format handling."""
def test_png_stays_png(self): def test_png_stays_png(self):
"""PNG input should produce PNG output.""" """PNG input should produce PNG output."""
fmt, ext = get_output_format('PNG') fmt, ext = get_output_format("PNG")
assert fmt == 'PNG' assert fmt == "PNG"
assert ext == 'png' assert ext == "png"
def test_bmp_stays_bmp(self): def test_bmp_stays_bmp(self):
"""BMP input should produce BMP output.""" """BMP input should produce BMP output."""
fmt, ext = get_output_format('BMP') fmt, ext = get_output_format("BMP")
assert fmt == 'BMP' assert fmt == "BMP"
assert ext == 'bmp' assert ext == "bmp"
def test_jpeg_becomes_png(self): def test_jpeg_becomes_png(self):
"""JPEG input should produce PNG output (lossless).""" """JPEG input should produce PNG output (lossless)."""
fmt, ext = get_output_format('JPEG') fmt, ext = get_output_format("JPEG")
assert fmt == 'PNG' assert fmt == "PNG"
assert ext == 'png' assert ext == "png"
def test_gif_becomes_png(self): def test_gif_becomes_png(self):
"""GIF input should produce PNG output.""" """GIF input should produce PNG output."""
fmt, ext = get_output_format('GIF') fmt, ext = get_output_format("GIF")
assert fmt == 'PNG' assert fmt == "PNG"
assert ext == 'png' assert ext == "png"
def test_none_becomes_png(self): def test_none_becomes_png(self):
"""None format should default to PNG.""" """None format should default to PNG."""
fmt, ext = get_output_format(None) fmt, ext = get_output_format(None)
assert fmt == 'PNG' assert fmt == "PNG"
assert ext == 'png' assert ext == "png"
def test_unknown_becomes_png(self): def test_unknown_becomes_png(self):
"""Unknown format should default to PNG.""" """Unknown format should default to PNG."""
fmt, ext = get_output_format('UNKNOWN') fmt, ext = get_output_format("UNKNOWN")
assert fmt == 'PNG' assert fmt == "PNG"
assert ext == 'png' assert ext == "png"
# ============================================================================= # =============================================================================
# Header Overhead Test (v4.0.0) # Header Overhead Test (v4.0.0)
# ============================================================================= # =============================================================================
class TestConstants: class TestConstants:
"""Tests for constants and configuration.""" """Tests for constants and configuration."""
def test_header_overhead_value(self): def test_header_overhead_value(self):
"""Header overhead should be 66 bytes (v4.0.0: added flags byte).""" """Header overhead should be 66 bytes (v4.0.0: added flags byte)."""
from stegasoo.steganography import HEADER_OVERHEAD from stegasoo.steganography import HEADER_OVERHEAD
assert HEADER_OVERHEAD == 66 assert HEADER_OVERHEAD == 66
@@ -307,6 +313,7 @@ class TestConstants:
# Encode/Decode Tests (v4.0.0 Updated) # Encode/Decode Tests (v4.0.0 Updated)
# ============================================================================= # =============================================================================
class TestEncodeDecode: class TestEncodeDecode:
"""Tests for encoding and decoding functions.""" """Tests for encoding and decoding functions."""
@@ -322,19 +329,19 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=png_image, carrier_image=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
assert result.stego_image is not None assert result.stego_image is not None
assert len(result.stego_image) > 0 assert len(result.stego_image) > 0
assert result.filename.endswith('.png') assert result.filename.endswith(".png")
# v3.2.0: Use passphrase parameter, no date_str # v3.2.0: Use passphrase parameter, no date_str
decoded = decode( decoded = decode(
stego_image=result.stego_image, stego_image=result.stego_image,
reference_photo=png_image, reference_photo=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
assert decoded.message == message assert decoded.message == message
@@ -350,7 +357,7 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=png_image, carrier_image=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
# decode_text returns string directly # decode_text returns string directly
@@ -358,7 +365,7 @@ class TestEncodeDecode:
stego_image=result.stego_image, stego_image=result.stego_image,
reference_photo=png_image, reference_photo=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
assert decoded_text == message assert decoded_text == message
@@ -370,9 +377,9 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=png_image, carrier_image=png_image,
passphrase="test phrase here now", passphrase="test phrase here now",
pin="123456" pin="123456",
) )
assert result.filename.endswith('.png') assert result.filename.endswith(".png")
def test_bmp_carrier_produces_bmp(self, bmp_image, png_image): def test_bmp_carrier_produces_bmp(self, bmp_image, png_image):
"""BMP carrier should produce BMP output.""" """BMP carrier should produce BMP output."""
@@ -381,9 +388,9 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=bmp_image, carrier_image=bmp_image,
passphrase="test phrase here now", passphrase="test phrase here now",
pin="123456" pin="123456",
) )
assert result.filename.endswith('.bmp') assert result.filename.endswith(".bmp")
def test_jpeg_carrier_produces_png(self, jpeg_image, png_image): def test_jpeg_carrier_produces_png(self, jpeg_image, png_image):
"""JPEG carrier should produce PNG output (lossless).""" """JPEG carrier should produce PNG output (lossless)."""
@@ -392,9 +399,9 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=jpeg_image, carrier_image=jpeg_image,
passphrase="test phrase here now", passphrase="test phrase here now",
pin="123456" pin="123456",
) )
assert result.filename.endswith('.png') assert result.filename.endswith(".png")
def test_bmp_roundtrip(self, bmp_image, png_image): def test_bmp_roundtrip(self, bmp_image, png_image):
"""Full encode/decode cycle with BMP should work.""" """Full encode/decode cycle with BMP should work."""
@@ -407,15 +414,15 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=bmp_image, carrier_image=bmp_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
assert result.filename.endswith('.bmp') assert result.filename.endswith(".bmp")
decoded = decode( decoded = decode(
stego_image=result.stego_image, stego_image=result.stego_image,
reference_photo=png_image, reference_photo=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
assert decoded.message == message assert decoded.message == message
@@ -427,7 +434,7 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=png_image, carrier_image=png_image,
passphrase="test phrase here now", passphrase="test phrase here now",
pin="123456" pin="123456",
) )
with pytest.raises((stegasoo.DecryptionError, stegasoo.ExtractionError)): with pytest.raises((stegasoo.DecryptionError, stegasoo.ExtractionError)):
@@ -435,7 +442,7 @@ class TestEncodeDecode:
stego_image=result.stego_image, stego_image=result.stego_image,
reference_photo=png_image, reference_photo=png_image,
passphrase="test phrase here now", passphrase="test phrase here now",
pin="654321" # Wrong PIN pin="654321", # Wrong PIN
) )
def test_wrong_passphrase_fails(self, png_image): def test_wrong_passphrase_fails(self, png_image):
@@ -445,7 +452,7 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=png_image, carrier_image=png_image,
passphrase="correct phrase here now", passphrase="correct phrase here now",
pin="123456" pin="123456",
) )
with pytest.raises((stegasoo.DecryptionError, stegasoo.ExtractionError)): with pytest.raises((stegasoo.DecryptionError, stegasoo.ExtractionError)):
@@ -453,7 +460,7 @@ class TestEncodeDecode:
stego_image=result.stego_image, stego_image=result.stego_image,
reference_photo=png_image, reference_photo=png_image,
passphrase="wrong phrase here now", # Wrong passphrase passphrase="wrong phrase here now", # Wrong passphrase
pin="123456" pin="123456",
) )
def test_unicode_message(self, png_image): def test_unicode_message(self, png_image):
@@ -467,14 +474,14 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=png_image, carrier_image=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
decoded = decode( decoded = decode(
stego_image=result.stego_image, stego_image=result.stego_image,
reference_photo=png_image, reference_photo=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
assert decoded.message == message assert decoded.message == message
@@ -486,18 +493,20 @@ class TestEncodeDecode:
reference_photo=png_image, reference_photo=png_image,
carrier_image=png_image, carrier_image=png_image,
passphrase="test phrase here now", passphrase="test phrase here now",
pin="123456" pin="123456",
) )
# Filename format: {random_hex}_{YYYYMMDD}.{ext} # Filename format: {random_hex}_{YYYYMMDD}.{ext}
# e.g., "a1b2c3d4_20251227.png" # e.g., "a1b2c3d4_20251227.png"
import re import re
assert re.search(r'^[a-f0-9]{8}_\d{8}\.png$', result.filename)
assert re.search(r"^[a-f0-9]{8}_\d{8}\.png$", result.filename)
# ============================================================================= # =============================================================================
# DCT Mode Tests (v3.2.0) # DCT Mode Tests (v3.2.0)
# ============================================================================= # =============================================================================
class TestDCTMode: class TestDCTMode:
"""Tests for DCT steganography mode.""" """Tests for DCT steganography mode."""
@@ -519,7 +528,7 @@ class TestDCTMode:
carrier_image=large_png_image, carrier_image=large_png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
embed_mode='dct' embed_mode="dct",
) )
assert result.stego_image is not None assert result.stego_image is not None
@@ -528,7 +537,7 @@ class TestDCTMode:
stego_image=result.stego_image, stego_image=result.stego_image,
reference_photo=large_png_image, reference_photo=large_png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin pin=pin,
) )
assert decoded.message == message assert decoded.message == message
@@ -545,7 +554,7 @@ class TestDCTMode:
carrier_image=large_png_image, carrier_image=large_png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
embed_mode='dct' embed_mode="dct",
) )
# Decode with auto mode (default) # Decode with auto mode (default)
@@ -554,7 +563,7 @@ class TestDCTMode:
reference_photo=large_png_image, reference_photo=large_png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
embed_mode='auto' embed_mode="auto",
) )
assert decoded.message == message assert decoded.message == message
@@ -564,19 +573,20 @@ class TestDCTMode:
# Version Tests # Version Tests
# ============================================================================= # =============================================================================
class TestVersion: class TestVersion:
"""Tests for version information.""" """Tests for version information."""
def test_version_exists(self): def test_version_exists(self):
"""Version string should exist and be valid.""" """Version string should exist and be valid."""
assert hasattr(stegasoo, '__version__') assert hasattr(stegasoo, "__version__")
parts = stegasoo.__version__.split('.') parts = stegasoo.__version__.split(".")
assert len(parts) >= 2 assert len(parts) >= 2
assert all(p.isdigit() for p in parts[:2]) assert all(p.isdigit() for p in parts[:2])
def test_version_is_4_0_0(self): def test_version_is_4_0_0(self):
"""Version should be 4.0.0 or higher.""" """Version should be 4.0.0 or higher."""
parts = stegasoo.__version__.split('.') parts = stegasoo.__version__.split(".")
major = int(parts[0]) major = int(parts[0])
assert major >= 4 assert major >= 4
@@ -585,6 +595,7 @@ class TestVersion:
# Backward Compatibility Tests # Backward Compatibility Tests
# ============================================================================= # =============================================================================
class TestBackwardCompatibility: class TestBackwardCompatibility:
"""Tests for backward compatibility handling.""" """Tests for backward compatibility handling."""
@@ -596,7 +607,7 @@ class TestBackwardCompatibility:
reference_photo=png_image, reference_photo=png_image,
carrier_image=png_image, carrier_image=png_image,
day_phrase="old style phrase", # Old parameter name day_phrase="old style phrase", # Old parameter name
pin="123456" pin="123456",
) )
def test_old_date_str_parameter_raises(self, png_image): def test_old_date_str_parameter_raises(self, png_image):
@@ -608,7 +619,7 @@ class TestBackwardCompatibility:
carrier_image=png_image, carrier_image=png_image,
passphrase="test phrase here now", passphrase="test phrase here now",
pin="123456", pin="123456",
date_str="2025-01-01" # Removed parameter date_str="2025-01-01", # Removed parameter
) )
@@ -616,6 +627,7 @@ class TestBackwardCompatibility:
# Channel Key Tests (v4.0.0) # Channel Key Tests (v4.0.0)
# ============================================================================= # =============================================================================
class TestChannelKey: class TestChannelKey:
"""Tests for channel key functionality (v4.0.0).""" """Tests for channel key functionality (v4.0.0)."""
@@ -624,7 +636,7 @@ class TestChannelKey:
key = generate_channel_key() key = generate_channel_key()
# Format: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX (8 groups of 4) # Format: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX (8 groups of 4)
assert len(key) == 39 assert len(key) == 39
parts = key.split('-') parts = key.split("-")
assert len(parts) == 8 assert len(parts) == 8
for part in parts: for part in parts:
assert len(part) == 4 assert len(part) == 4
@@ -665,7 +677,7 @@ class TestChannelKey:
carrier_image=png_image, carrier_image=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
channel_key=channel_key channel_key=channel_key,
) )
assert result.stego_image is not None assert result.stego_image is not None
@@ -676,7 +688,7 @@ class TestChannelKey:
reference_photo=png_image, reference_photo=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
channel_key=channel_key channel_key=channel_key,
) )
assert decoded.message == message assert decoded.message == message
@@ -696,7 +708,7 @@ class TestChannelKey:
carrier_image=png_image, carrier_image=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
channel_key=channel_key1 channel_key=channel_key1,
) )
# Decode with different channel key should fail # Decode with different channel key should fail
@@ -706,7 +718,7 @@ class TestChannelKey:
reference_photo=png_image, reference_photo=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
channel_key=channel_key2 channel_key=channel_key2,
) )
def test_encode_decode_public_mode(self, png_image): def test_encode_decode_public_mode(self, png_image):
@@ -722,7 +734,7 @@ class TestChannelKey:
carrier_image=png_image, carrier_image=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
channel_key="" # Explicit public mode channel_key="", # Explicit public mode
) )
# Decode without channel key # Decode without channel key
@@ -731,7 +743,7 @@ class TestChannelKey:
reference_photo=png_image, reference_photo=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
channel_key="" # Explicit public mode channel_key="", # Explicit public mode
) )
assert decoded.message == message assert decoded.message == message
@@ -749,7 +761,7 @@ class TestChannelKey:
carrier_image=png_image, carrier_image=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
channel_key="" # Public mode channel_key="", # Public mode
) )
# Decode with channel key should fail # Decode with channel key should fail
@@ -760,5 +772,5 @@ class TestChannelKey:
reference_photo=png_image, reference_photo=png_image,
passphrase=passphrase, passphrase=passphrase,
pin=pin, pin=pin,
channel_key=channel_key channel_key=channel_key,
) )