Apply black formatter to all Python files

Reformatted 29 files for consistent code style and CI compliance.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee
2026-01-02 17:44:41 -05:00
parent 221678d934
commit afa88bc73b
29 changed files with 2067 additions and 1814 deletions

View File

@@ -31,7 +31,7 @@ from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, Field
# Add parent to path for development
sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'src'))
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
from stegasoo import (
MAX_FILE_PAYLOAD_SIZE,
@@ -71,6 +71,7 @@ try:
extract_key_from_qr,
has_qr_read,
)
HAS_QR_READ = has_qr_read()
except ImportError:
HAS_QR_READ = False
@@ -130,6 +131,7 @@ DctOutputFormatType = Literal["png", "jpeg"]
# MODELS
# ============================================================================
class GenerateRequest(BaseModel):
use_pin: bool = True
use_rsa: bool = False
@@ -139,7 +141,7 @@ class GenerateRequest(BaseModel):
default=DEFAULT_PASSPHRASE_WORDS,
ge=MIN_PASSPHRASE_WORDS,
le=MAX_PASSPHRASE_WORDS,
description="Words per passphrase (v3.2.0: default increased to 4)"
description="Words per passphrase (v3.2.0: default increased to 4)",
)
@@ -150,8 +152,7 @@ class GenerateResponse(BaseModel):
entropy: dict[str, int]
# Legacy field for compatibility
phrases: dict[str, str] | None = Field(
default=None,
description="Deprecated: Use 'passphrase' instead"
default=None, description="Deprecated: Use 'passphrase' instead"
)
@@ -166,24 +167,25 @@ class EncodeRequest(BaseModel):
# Channel key (v4.0.0)
channel_key: str | None = Field(
default=None,
description="Channel key for deployment isolation. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key"
description="Channel key for deployment isolation. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key",
)
embed_mode: EmbedModeType = Field(
default="lsb",
description="Embedding mode: 'lsb' (default, color) or 'dct' (requires scipy)"
description="Embedding mode: 'lsb' (default, color) or 'dct' (requires scipy)",
)
dct_output_format: DctOutputFormatType = Field(
default="png",
description="DCT output format: 'png' (lossless) or 'jpeg' (smaller). Only applies to DCT mode."
description="DCT output format: 'png' (lossless) or 'jpeg' (smaller). Only applies to DCT mode.",
)
dct_color_mode: DctColorModeType = Field(
default="grayscale",
description="DCT color mode: 'grayscale' (default) or 'color' (preserves colors). Only applies to DCT mode."
description="DCT color mode: 'grayscale' (default) or 'color' (preserves colors). Only applies to DCT mode.",
)
class EncodeFileRequest(BaseModel):
"""Request for embedding a file (base64-encoded)."""
file_data_base64: str
filename: str
mime_type: str | None = None
@@ -196,19 +198,19 @@ class EncodeFileRequest(BaseModel):
# Channel key (v4.0.0)
channel_key: str | None = Field(
default=None,
description="Channel key for deployment isolation. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key"
description="Channel key for deployment isolation. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key",
)
embed_mode: EmbedModeType = Field(
default="lsb",
description="Embedding mode: 'lsb' (default, color) or 'dct' (requires scipy)"
description="Embedding mode: 'lsb' (default, color) or 'dct' (requires scipy)",
)
dct_output_format: DctOutputFormatType = Field(
default="png",
description="DCT output format: 'png' (lossless) or 'jpeg' (smaller). Only applies to DCT mode."
description="DCT output format: 'png' (lossless) or 'jpeg' (smaller). Only applies to DCT mode.",
)
dct_color_mode: DctColorModeType = Field(
default="grayscale",
description="DCT color mode: 'grayscale' (default) or 'color' (preserves colors). Only applies to DCT mode."
description="DCT color mode: 'grayscale' (default) or 'color' (preserves colors). Only applies to DCT mode.",
)
@@ -218,30 +220,23 @@ class EncodeResponse(BaseModel):
capacity_used_percent: float
embed_mode: str = Field(description="Embedding mode used: 'lsb' or 'dct'")
output_format: str = Field(
default="png",
description="Output format: 'png' or 'jpeg' (for DCT mode)"
default="png", description="Output format: 'png' or 'jpeg' (for DCT mode)"
)
color_mode: str = Field(
default="color",
description="Color mode: 'color' (LSB/DCT color) or 'grayscale' (DCT grayscale)"
description="Color mode: 'color' (LSB/DCT color) or 'grayscale' (DCT grayscale)",
)
# Channel key info (v4.0.0)
channel_mode: str = Field(
default="public",
description="Channel mode: 'public' or 'private'"
)
channel_mode: str = Field(default="public", description="Channel mode: 'public' or 'private'")
channel_fingerprint: str | None = Field(
default=None,
description="Channel key fingerprint (if private mode)"
default=None, description="Channel key fingerprint (if private mode)"
)
# Legacy fields (v3.2.0: no longer used in crypto)
date_used: str | None = Field(
default=None,
description="Deprecated: Date no longer used in v3.2.0"
default=None, description="Deprecated: Date no longer used in v3.2.0"
)
day_of_week: str | None = Field(
default=None,
description="Deprecated: Date no longer used in v3.2.0"
default=None, description="Deprecated: Date no longer used in v3.2.0"
)
@@ -255,16 +250,16 @@ class DecodeRequest(BaseModel):
# Channel key (v4.0.0)
channel_key: str | None = Field(
default=None,
description="Channel key for decryption. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key"
description="Channel key for decryption. null=auto (use server config), ''=public mode, 'XXXX-...'=explicit key",
)
embed_mode: ExtractModeType = Field(
default="auto",
description="Extraction mode: 'auto' (default), 'lsb', or 'dct'"
default="auto", description="Extraction mode: 'auto' (default), 'lsb', or 'dct'"
)
class DecodeResponse(BaseModel):
"""Response for decode - can be text or file."""
payload_type: str # 'text' or 'file'
message: str | None = None # For text
file_data_base64: str | None = None # For file (base64-encoded)
@@ -274,6 +269,7 @@ class DecodeResponse(BaseModel):
class ModeCapacity(BaseModel):
"""Capacity info for a single mode."""
capacity_bytes: int
capacity_kb: float
available: bool
@@ -287,22 +283,22 @@ class ImageInfoResponse(BaseModel):
capacity_bytes: int = Field(description="LSB mode capacity (for backwards compatibility)")
capacity_kb: int = Field(description="LSB mode capacity in KB")
modes: dict[str, ModeCapacity] | None = Field(
default=None,
description="Capacity by embedding mode (v3.0+)"
default=None, description="Capacity by embedding mode (v3.0+)"
)
class CompareModesRequest(BaseModel):
"""Request for comparing embedding modes."""
carrier_image_base64: str
payload_size: int | None = Field(
default=None,
description="Optional payload size to check if it fits"
default=None, description="Optional payload size to check if it fits"
)
class CompareModesResponse(BaseModel):
"""Response comparing LSB and DCT modes."""
width: int
height: int
lsb: dict
@@ -313,6 +309,7 @@ class CompareModesResponse(BaseModel):
class DctModeInfo(BaseModel):
"""Detailed DCT mode information."""
available: bool
name: str
description: str
@@ -324,6 +321,7 @@ class DctModeInfo(BaseModel):
class ChannelStatusResponse(BaseModel):
"""Response for channel key status (v4.0.0)."""
mode: str = Field(description="'public' or 'private'")
configured: bool = Field(description="Whether a channel key is configured")
fingerprint: str | None = Field(default=None, description="Key fingerprint (partial)")
@@ -333,6 +331,7 @@ class ChannelStatusResponse(BaseModel):
class ChannelGenerateResponse(BaseModel):
"""Response for channel key generation (v4.0.0)."""
key: str = Field(description="Generated channel key")
fingerprint: str = Field(description="Key fingerprint")
saved: bool = Field(default=False, description="Whether key was saved to config")
@@ -341,19 +340,18 @@ class ChannelGenerateResponse(BaseModel):
class ChannelSetRequest(BaseModel):
"""Request to set channel key (v4.0.0)."""
key: str = Field(description="Channel key to set")
location: str = Field(default="user", description="'user' or 'project'")
class ModesResponse(BaseModel):
"""Response showing available embedding modes."""
lsb: dict
dct: DctModeInfo
# Channel key status (v4.0.0)
channel: dict | None = Field(
default=None,
description="Channel key status (v4.0.0)"
)
channel: dict | None = Field(default=None, description="Channel key status (v4.0.0)")
class StatusResponse(BaseModel):
@@ -363,18 +361,10 @@ class StatusResponse(BaseModel):
has_dct: bool
max_payload_kb: int
available_modes: list[str]
dct_features: dict | None = Field(
default=None,
description="DCT mode features (v3.0.1+)"
)
dct_features: dict | None = Field(default=None, description="DCT mode features (v3.0.1+)")
# Channel key status (v4.0.0)
channel: dict | None = Field(
default=None,
description="Channel key status (v4.0.0)"
)
breaking_changes: dict = Field(
description="v4.0.0 breaking changes"
)
channel: dict | None = Field(default=None, description="Channel key status (v4.0.0)")
breaking_changes: dict = Field(description="v4.0.0 breaking changes")
class QrExtractResponse(BaseModel):
@@ -385,6 +375,7 @@ class QrExtractResponse(BaseModel):
class WillFitRequest(BaseModel):
"""Request to check if payload will fit."""
carrier_image_base64: str
payload_size: int
embed_mode: EmbedModeType = "lsb"
@@ -392,6 +383,7 @@ class WillFitRequest(BaseModel):
class WillFitResponse(BaseModel):
"""Response for will_fit check."""
fits: bool
payload_size: int
capacity: int
@@ -409,6 +401,7 @@ class ErrorResponse(BaseModel):
# HELPER: RESOLVE CHANNEL KEY
# ============================================================================
def _resolve_channel_key(channel_key: str | None) -> str | None:
"""
Resolve channel key from API parameter.
@@ -436,8 +429,7 @@ def _resolve_channel_key(channel_key: str | None) -> str | None:
# Explicit key - validate format
if not validate_channel_key(channel_key):
raise HTTPException(
400,
"Invalid channel key format. Expected: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
400, "Invalid channel key format. Expected: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
)
return channel_key
@@ -461,7 +453,7 @@ def _get_channel_info(channel_key: str | None) -> tuple[str, str | None]:
# Auto mode - check server config
if has_channel_key():
status = get_channel_status()
return "private", status.get('fingerprint')
return "private", status.get("fingerprint")
return "public", None
@@ -470,6 +462,7 @@ def _get_channel_info(channel_key: str | None) -> tuple[str, str | None]:
# ROUTES - STATUS & INFO
# ============================================================================
@app.get("/", response_model=StatusResponse)
async def root():
"""Get API status and configuration."""
@@ -488,10 +481,10 @@ async def root():
# Channel key status (v4.0.0)
channel_status = get_channel_status()
channel_info = {
"mode": channel_status['mode'],
"configured": channel_status['configured'],
"fingerprint": channel_status.get('fingerprint'),
"source": channel_status.get('source'),
"mode": channel_status["mode"],
"configured": channel_status["configured"],
"fingerprint": channel_status.get("fingerprint"),
"source": channel_status.get("source"),
}
return StatusResponse(
@@ -510,8 +503,8 @@ async def root():
"v3_notes": {
"date_removed": "No date_str parameter needed - encode/decode anytime",
"passphrase_renamed": "day_phrase → passphrase (single passphrase, no daily rotation)",
}
}
},
},
)
@@ -525,9 +518,9 @@ async def api_modes():
# Channel status
channel_status = get_channel_status()
channel_info = {
"mode": channel_status['mode'],
"configured": channel_status['configured'],
"fingerprint": channel_status.get('fingerprint'),
"mode": channel_status["mode"],
"configured": channel_status["configured"],
"fingerprint": channel_status.get("fingerprint"),
}
return ModesResponse(
@@ -555,6 +548,7 @@ async def api_modes():
# ROUTES - CHANNEL KEY (v4.0.0)
# ============================================================================
@app.get("/channel/status", response_model=ChannelStatusResponse)
async def api_channel_status(
reveal: bool = Query(False, description="Include full key in response")
@@ -570,11 +564,11 @@ async def api_channel_status(
status = get_channel_status()
return ChannelStatusResponse(
mode=status['mode'],
configured=status['configured'],
fingerprint=status.get('fingerprint'),
source=status.get('source'),
key=status.get('key') if reveal and status['configured'] else None,
mode=status["mode"],
configured=status["configured"],
fingerprint=status.get("fingerprint"),
source=status.get("source"),
key=status.get("key") if reveal and status["configured"] else None,
)
@@ -601,11 +595,11 @@ async def api_channel_generate(
save_location = None
if save:
set_channel_key(key, location='user')
set_channel_key(key, location="user")
saved = True
save_location = "~/.stegasoo/channel.key"
elif save_project:
set_channel_key(key, location='project')
set_channel_key(key, location="project")
saved = True
save_location = "./config/channel.key"
@@ -626,11 +620,10 @@ async def api_channel_set(request: ChannelSetRequest):
"""
if not validate_channel_key(request.key):
raise HTTPException(
400,
"Invalid channel key format. Expected: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
400, "Invalid channel key format. Expected: XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
)
if request.location not in ('user', 'project'):
if request.location not in ("user", "project"):
raise HTTPException(400, "location must be 'user' or 'project'")
set_channel_key(request.key, location=request.location)
@@ -638,8 +631,8 @@ async def api_channel_set(request: ChannelSetRequest):
status = get_channel_status()
return {
"success": True,
"location": status.get('source'),
"fingerprint": status.get('fingerprint'),
"location": status.get("source"),
"fingerprint": status.get("fingerprint"),
}
@@ -655,9 +648,9 @@ async def api_channel_clear(
Note: Does not affect environment variables.
"""
if location == "all":
clear_channel_key(location='user')
clear_channel_key(location='project')
elif location in ('user', 'project'):
clear_channel_key(location="user")
clear_channel_key(location="project")
elif location in ("user", "project"):
clear_channel_key(location=location)
else:
raise HTTPException(400, "location must be 'user', 'project', or 'all'")
@@ -665,9 +658,9 @@ async def api_channel_clear(
status = get_channel_status()
return {
"success": True,
"mode": status['mode'],
"still_configured": status['configured'],
"remaining_source": status.get('source'),
"mode": status["mode"],
"still_configured": status["configured"],
"remaining_source": status.get("source"),
}
@@ -684,28 +677,30 @@ async def api_compare_modes(request: CompareModesRequest):
comparison = compare_modes(carrier)
response = CompareModesResponse(
width=comparison['width'],
height=comparison['height'],
width=comparison["width"],
height=comparison["height"],
lsb={
"capacity_bytes": comparison['lsb']['capacity_bytes'],
"capacity_kb": round(comparison['lsb']['capacity_kb'], 1),
"capacity_bytes": comparison["lsb"]["capacity_bytes"],
"capacity_kb": round(comparison["lsb"]["capacity_kb"], 1),
"available": True,
"output_format": comparison['lsb']['output'],
"output_format": comparison["lsb"]["output"],
},
dct={
"capacity_bytes": comparison['dct']['capacity_bytes'],
"capacity_kb": round(comparison['dct']['capacity_kb'], 1),
"available": comparison['dct']['available'],
"capacity_bytes": comparison["dct"]["capacity_bytes"],
"capacity_kb": round(comparison["dct"]["capacity_kb"], 1),
"available": comparison["dct"]["available"],
"output_formats": ["png", "jpeg"],
"color_modes": ["grayscale", "color"],
"ratio_vs_lsb_percent": round(comparison['dct']['ratio_vs_lsb'], 1),
"ratio_vs_lsb_percent": round(comparison["dct"]["ratio_vs_lsb"], 1),
},
recommendation="lsb" if not comparison['dct']['available'] else "dct for stealth, lsb for capacity"
recommendation=(
"lsb" if not comparison["dct"]["available"] else "dct for stealth, lsb for capacity"
),
)
if request.payload_size:
fits_lsb = request.payload_size <= comparison['lsb']['capacity_bytes']
fits_dct = request.payload_size <= comparison['dct']['capacity_bytes']
fits_lsb = request.payload_size <= comparison["lsb"]["capacity_bytes"]
fits_dct = request.payload_size <= comparison["dct"]["capacity_bytes"]
response.payload_check = {
"size_bytes": request.payload_size,
@@ -714,7 +709,7 @@ async def api_compare_modes(request: CompareModesRequest):
}
# Update recommendation based on payload
if fits_dct and comparison['dct']['available']:
if fits_dct and comparison["dct"]["available"]:
response.recommendation = "dct (payload fits, better stealth)"
elif fits_lsb:
response.recommendation = "lsb (payload too large for dct)"
@@ -743,12 +738,12 @@ async def api_will_fit(request: WillFitRequest):
result = will_fit_by_mode(request.payload_size, carrier, embed_mode=request.embed_mode)
return WillFitResponse(
fits=result['fits'],
payload_size=result['payload_size'],
capacity=result['capacity'],
usage_percent=round(result['usage_percent'], 1),
headroom=result['headroom'],
mode=request.embed_mode
fits=result["fits"],
payload_size=result["payload_size"],
capacity=result["capacity"],
usage_percent=round(result["usage_percent"], 1),
headroom=result["headroom"],
mode=request.embed_mode,
)
except HTTPException:
@@ -761,6 +756,7 @@ async def api_will_fit(request: WillFitRequest):
# ROUTES - QR CODE
# ============================================================================
@app.post("/extract-key-from-qr", response_model=QrExtractResponse)
async def api_extract_key_from_qr(
qr_image: UploadFile = File(..., description="QR code image containing RSA key")
@@ -772,10 +768,7 @@ async def api_extract_key_from_qr(
Returns the PEM-encoded key if found.
"""
if not HAS_QR_READ:
raise HTTPException(
501,
"QR code reading not available. Install pyzbar and libzbar."
)
raise HTTPException(501, "QR code reading not available. Install pyzbar and libzbar.")
try:
image_data = await qr_image.read()
@@ -784,10 +777,7 @@ async def api_extract_key_from_qr(
if key_pem:
return QrExtractResponse(success=True, key_pem=key_pem)
else:
return QrExtractResponse(
success=False,
error="No valid RSA key found in QR code"
)
return QrExtractResponse(success=False, error="No valid RSA key found in QR code")
except Exception as e:
return QrExtractResponse(success=False, error=str(e))
@@ -796,6 +786,7 @@ async def api_extract_key_from_qr(
# ROUTES - GENERATE
# ============================================================================
@app.post("/generate", response_model=GenerateResponse)
async def api_generate(request: GenerateRequest):
"""
@@ -829,9 +820,9 @@ async def api_generate(request: GenerateRequest):
"passphrase": creds.passphrase_entropy,
"pin": creds.pin_entropy,
"rsa": creds.rsa_entropy,
"total": creds.total_entropy
"total": creds.total_entropy,
},
phrases=None # Legacy field removed
phrases=None, # Legacy field removed
)
except Exception as e:
raise HTTPException(500, str(e))
@@ -841,6 +832,7 @@ async def api_generate(request: GenerateRequest):
# HELPER FUNCTION FOR DCT PARAMETERS
# ============================================================================
def _get_dct_params(embed_mode: str, dct_output_format: str, dct_color_mode: str) -> dict:
"""
Get DCT-specific parameters if DCT mode is selected.
@@ -876,6 +868,7 @@ def _get_output_info(embed_mode: str, dct_output_format: str, dct_color_mode: st
# ROUTES - ENCODE (JSON)
# ============================================================================
@app.post("/encode", response_model=EncodeResponse)
async def api_encode(request: EncodeRequest):
"""
@@ -900,9 +893,7 @@ async def api_encode(request: EncodeRequest):
# Get DCT parameters
dct_params = _get_dct_params(
request.embed_mode,
request.dct_output_format,
request.dct_color_mode
request.embed_mode, request.dct_output_format, request.dct_color_mode
)
# v4.0.0: Include channel_key
@@ -919,12 +910,10 @@ async def api_encode(request: EncodeRequest):
**dct_params,
)
stego_b64 = base64.b64encode(result.stego_image).decode('utf-8')
stego_b64 = base64.b64encode(result.stego_image).decode("utf-8")
output_format, color_mode, _ = _get_output_info(
request.embed_mode,
request.dct_output_format,
request.dct_color_mode
request.embed_mode, request.dct_output_format, request.dct_color_mode
)
# Get channel info for response
@@ -975,16 +964,12 @@ async def api_encode_file(request: EncodeFileRequest):
rsa_key = base64.b64decode(request.rsa_key_base64) if request.rsa_key_base64 else None
payload = FilePayload(
data=file_data,
filename=request.filename,
mime_type=request.mime_type
data=file_data, filename=request.filename, mime_type=request.mime_type
)
# Get DCT parameters
dct_params = _get_dct_params(
request.embed_mode,
request.dct_output_format,
request.dct_color_mode
request.embed_mode, request.dct_output_format, request.dct_color_mode
)
# v4.0.0: Include channel_key
@@ -1001,12 +986,10 @@ async def api_encode_file(request: EncodeFileRequest):
**dct_params,
)
stego_b64 = base64.b64encode(result.stego_image).decode('utf-8')
stego_b64 = base64.b64encode(result.stego_image).decode("utf-8")
output_format, color_mode, _ = _get_output_info(
request.embed_mode,
request.dct_output_format,
request.dct_color_mode
request.embed_mode, request.dct_output_format, request.dct_color_mode
)
# Get channel info for response
@@ -1037,6 +1020,7 @@ async def api_encode_file(request: EncodeFileRequest):
# ROUTES - DECODE (JSON)
# ============================================================================
@app.post("/decode", response_model=DecodeResponse)
async def api_decode(request: DecodeRequest):
"""
@@ -1073,21 +1057,18 @@ async def api_decode(request: DecodeRequest):
if result.is_file:
return DecodeResponse(
payload_type='file',
file_data_base64=base64.b64encode(result.file_data).decode('utf-8'),
payload_type="file",
file_data_base64=base64.b64encode(result.file_data).decode("utf-8"),
filename=result.filename,
mime_type=result.mime_type
mime_type=result.mime_type,
)
else:
return DecodeResponse(
payload_type='text',
message=result.message
)
return DecodeResponse(payload_type="text", message=result.message)
except DecryptionError as e:
# Provide helpful error message for channel key issues
error_msg = str(e)
if 'channel key' in error_msg.lower():
if "channel key" in error_msg.lower():
raise HTTPException(401, error_msg)
raise HTTPException(401, "Decryption failed. Check credentials.")
except StegasooError as e:
@@ -1100,6 +1081,7 @@ async def api_decode(request: DecodeRequest):
# ROUTES - ENCODE/DECODE (MULTIPART)
# ============================================================================
@app.post("/encode/multipart")
async def api_encode_multipart(
passphrase: str = Form(..., description="Passphrase (v3.2.0: renamed from day_phrase)"),
@@ -1112,7 +1094,9 @@ async def api_encode_multipart(
rsa_key_qr: UploadFile | None = File(None),
rsa_password: str = Form(""),
# Channel key (v4.0.0)
channel_key: str = Form("auto", description="Channel key: 'auto'=server config, 'none'=public, 'XXXX-...'=explicit"),
channel_key: str = Form(
"auto", description="Channel key: 'auto'=server config, 'none'=public, 'XXXX-...'=explicit"
),
embed_mode: str = Form("lsb"),
dct_output_format: str = Form("png"),
dct_color_mode: str = Form("grayscale"),
@@ -1162,14 +1146,13 @@ async def api_encode_multipart(
elif rsa_key_qr and rsa_key_qr.filename:
if not HAS_QR_READ:
raise HTTPException(
501,
"QR code reading not available. Install pyzbar and libzbar."
501, "QR code reading not available. Install pyzbar and libzbar."
)
qr_image_data = await rsa_key_qr.read()
key_pem = extract_key_from_qr(qr_image_data)
if not key_pem:
raise HTTPException(400, "Could not extract RSA key from QR code image")
rsa_key_data = key_pem.encode('utf-8')
rsa_key_data = key_pem.encode("utf-8")
rsa_key_from_qr = True
# QR code keys are never password-protected
@@ -1179,9 +1162,7 @@ async def api_encode_multipart(
if payload_file and payload_file.filename:
file_data = await payload_file.read()
payload = FilePayload(
data=file_data,
filename=payload_file.filename,
mime_type=payload_file.content_type
data=file_data, filename=payload_file.filename, mime_type=payload_file.content_type
)
elif message:
payload = message
@@ -1251,7 +1232,9 @@ async def api_decode_multipart(
rsa_key_qr: UploadFile | None = File(None),
rsa_password: str = Form(""),
# Channel key (v4.0.0)
channel_key: str = Form("auto", description="Channel key: 'auto'=server config, 'none'=public, 'XXXX-...'=explicit"),
channel_key: str = Form(
"auto", description="Channel key: 'auto'=server config, 'none'=public, 'XXXX-...'=explicit"
),
embed_mode: str = Form("auto"),
):
"""
@@ -1291,14 +1274,13 @@ async def api_decode_multipart(
elif rsa_key_qr and rsa_key_qr.filename:
if not HAS_QR_READ:
raise HTTPException(
501,
"QR code reading not available. Install pyzbar and libzbar."
501, "QR code reading not available. Install pyzbar and libzbar."
)
qr_image_data = await rsa_key_qr.read()
key_pem = extract_key_from_qr(qr_image_data)
if not key_pem:
raise HTTPException(400, "Could not extract RSA key from QR code image")
rsa_key_data = key_pem.encode('utf-8')
rsa_key_data = key_pem.encode("utf-8")
rsa_key_from_qr = True
# QR code keys are never password-protected
@@ -1318,20 +1300,17 @@ async def api_decode_multipart(
if result.is_file:
return DecodeResponse(
payload_type='file',
file_data_base64=base64.b64encode(result.file_data).decode('utf-8'),
payload_type="file",
file_data_base64=base64.b64encode(result.file_data).decode("utf-8"),
filename=result.filename,
mime_type=result.mime_type
mime_type=result.mime_type,
)
else:
return DecodeResponse(
payload_type='text',
message=result.message
)
return DecodeResponse(payload_type="text", message=result.message)
except DecryptionError as e:
error_msg = str(e)
if 'channel key' in error_msg.lower():
if "channel key" in error_msg.lower():
raise HTTPException(401, error_msg)
raise HTTPException(401, "Decryption failed. Check credentials.")
except StegasooError as e:
@@ -1346,10 +1325,11 @@ async def api_decode_multipart(
# ROUTES - IMAGE INFO
# ============================================================================
@app.post("/image/info", response_model=ImageInfoResponse)
async def api_image_info(
image: UploadFile = File(...),
include_modes: bool = Query(True, description="Include capacity by mode (v3.0+)")
include_modes: bool = Query(True, description="Include capacity by mode (v3.0+)"),
):
"""
Get information about an image's capacity.
@@ -1363,29 +1343,29 @@ async def api_image_info(
if not result.is_valid:
raise HTTPException(400, result.error_message)
capacity = calculate_capacity_by_mode(image_data, 'lsb')
capacity = calculate_capacity_by_mode(image_data, "lsb")
response = ImageInfoResponse(
width=result.details['width'],
height=result.details['height'],
pixels=result.details['pixels'],
width=result.details["width"],
height=result.details["height"],
pixels=result.details["pixels"],
capacity_bytes=capacity,
capacity_kb=capacity // 1024
capacity_kb=capacity // 1024,
)
if include_modes:
comparison = compare_modes(image_data)
response.modes = {
"lsb": ModeCapacity(
capacity_bytes=comparison['lsb']['capacity_bytes'],
capacity_kb=round(comparison['lsb']['capacity_kb'], 1),
capacity_bytes=comparison["lsb"]["capacity_bytes"],
capacity_kb=round(comparison["lsb"]["capacity_kb"], 1),
available=True,
output_format=comparison['lsb']['output'],
output_format=comparison["lsb"]["output"],
),
"dct": ModeCapacity(
capacity_bytes=comparison['dct']['capacity_bytes'],
capacity_kb=round(comparison['dct']['capacity_kb'], 1),
available=comparison['dct']['available'],
capacity_bytes=comparison["dct"]["capacity_bytes"],
capacity_kb=round(comparison["dct"]["capacity_kb"], 1),
available=comparison["dct"]["available"],
output_format="PNG/JPEG (grayscale or color)",
),
}
@@ -1402,18 +1382,17 @@ async def api_image_info(
# ERROR HANDLERS
# ============================================================================
@app.exception_handler(StegasooError)
async def stegasoo_error_handler(request, exc):
return JSONResponse(
status_code=400,
content={"error": type(exc).__name__, "detail": str(exc)}
)
return JSONResponse(status_code=400, content={"error": type(exc).__name__, "detail": str(exc)})
# ============================================================================
# MAIN
# ============================================================================
if __name__ == '__main__':
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)