diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 0000000..1388f5a --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,5 @@ +{ + "enabledPlugins": { + "church@church": true + } +} diff --git a/agentstuff/pyproject.toml b/agentstuff/pyproject.toml new file mode 100644 index 0000000..71ab958 --- /dev/null +++ b/agentstuff/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = "sentiment-agent" +version = "0.1.0" +description = "AI agent for gathering data and performing sentiment analysis" +requires-python = ">=3.11" +dependencies = [ + "claude-agent-sdk", + "anyio", + "httpx", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "ruff", + "mypy", +] + +[project.scripts] +sentiment-agent = "sentiment_agent.main:main" + +[tool.ruff] +line-length = 100 + +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.mypy] +python_version = "3.11" +ignore_missing_imports = true diff --git a/agentstuff/sentiment_agent/__init__.py b/agentstuff/sentiment_agent/__init__.py new file mode 100644 index 0000000..3301b88 --- /dev/null +++ b/agentstuff/sentiment_agent/__init__.py @@ -0,0 +1,3 @@ +"""Sentiment analysis agent powered by Claude Agent SDK.""" + +__version__ = "0.1.0" diff --git a/agentstuff/sentiment_agent/agent.py b/agentstuff/sentiment_agent/agent.py new file mode 100644 index 0000000..bbc4d81 --- /dev/null +++ b/agentstuff/sentiment_agent/agent.py @@ -0,0 +1,115 @@ +"""Core sentiment analysis agent using Claude Agent SDK.""" + +from __future__ import annotations + +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ClaudeSDKClient, + ResultMessage, + TextBlock, +) + +from sentiment_agent.config import SafetyConfig +from sentiment_agent.tools import create_social_tools_server + +SYSTEM_PROMPT = """\ +You are a sentiment analysis agent. Your job is to gather data from multiple \ +platforms and produce a structured, evidence-based sentiment report. + +## Rules — you MUST follow these + +1. **Budget awareness.** You have a limited API call budget. Call \ +`get_api_budget_status` before starting and after every few tool calls. \ +Stop gathering data when you have <5 calls remaining and begin your analysis. + +2. **Credibility first.** Every tool result includes credibility scores and \ +bot/disinfo flags. You MUST: + - NEVER quote or cite posts marked `likely_inauthentic` (score < 0.3). + - Flag posts marked `suspicious` (score 0.3–0.5) with a warning when citing them. + - Give more weight to `likely_authentic` posts (score ≥ 0.7). + - If coordination warnings appear (copy-paste campaigns, burst posting), \ + call them out prominently in your report. + +3. **Platform diversity.** Gather from at least 2 different platforms before \ +analyzing. Do not over-index on a single source. + +4. **No fabrication.** Only report on data you actually retrieved. If a tool \ +call fails or returns no results, say so — do not invent data. + +5. **Structured output.** Your final report MUST include these sections: + - **Data Quality Summary**: platforms queried, posts analyzed vs excluded, \ + coordination warnings + - **Overall Sentiment**: score (-1.0 to +1.0) and label \ + (very negative / negative / mixed / neutral / positive / very positive) + - **Platform Breakdown**: sentiment per platform with sample size + - **Key Themes**: top 3-5 themes with sentiment direction + - **Credibility Concerns**: any bot networks, disinfo patterns, or \ + coordinated campaigns detected + - **Notable Quotes**: 3-5 representative quotes (authentic sources only, \ + with credibility score noted) + - **Confidence Assessment**: how confident you are in the analysis given \ + data quality and volume + +6. **Scope discipline.** Stay focused on the requested topic. Do not expand \ +scope, follow tangents, or analyze adjacent topics unless explicitly asked. + +7. **No side effects.** Do not write files, run commands, or take any action \ +beyond reading data and producing your report. +""" + + +async def run_sentiment_analysis( + topic: str, + sources: list[str] | None = None, + config: SafetyConfig | None = None, +) -> str: + """Run the sentiment analysis agent on a given topic. + + Args: + topic: The topic or subject to analyze sentiment for. + sources: Optional list of URLs or data sources to analyze. + config: Safety configuration. Defaults to SafetyConfig.from_env(). + + Returns: + The agent's sentiment analysis report. + """ + config = config or SafetyConfig.from_env() + + source_instructions = "" + if sources: + source_list = "\n".join(f"- {s}" for s in sources) + source_instructions = f"\n\nAlso analyze these specific sources:\n{source_list}" + + prompt = ( + f"Perform a sentiment analysis on the following topic: {topic}\n\n" + "Start by calling `get_api_budget_status` to check your budget, then " + "gather data from multiple platforms (Reddit, Hacker News, Bluesky if " + "configured, and web search). Pay close attention to credibility scores " + "and coordination warnings in the results." + f"{source_instructions}" + ) + + social_server = create_social_tools_server(config) + + options = ClaudeAgentOptions( + # Only allow read-only tools — no Write/Bash to prevent side effects + allowed_tools=["WebSearch", "WebFetch", "Read"], + max_turns=config.max_turns, + max_budget_usd=config.max_budget_usd, + mcp_servers={"social": social_server}, + system_prompt=SYSTEM_PROMPT, + ) + + result_text = "" + async with ClaudeSDKClient(options=options) as client: + await client.query(prompt) + async for message in client.receive_response(): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(block.text, end="", flush=True) + if isinstance(message, ResultMessage): + result_text = message.result + + return result_text diff --git a/agentstuff/sentiment_agent/clients/__init__.py b/agentstuff/sentiment_agent/clients/__init__.py new file mode 100644 index 0000000..2c3b142 --- /dev/null +++ b/agentstuff/sentiment_agent/clients/__init__.py @@ -0,0 +1 @@ +"""API clients for social media and forum data sources.""" diff --git a/agentstuff/sentiment_agent/clients/bluesky.py b/agentstuff/sentiment_agent/clients/bluesky.py new file mode 100644 index 0000000..fa50a9f --- /dev/null +++ b/agentstuff/sentiment_agent/clients/bluesky.py @@ -0,0 +1,166 @@ +"""Bluesky client using the AT Protocol API. + +Search requires authentication. Set BLUESKY_HANDLE and BLUESKY_APP_PASSWORD +env vars. Create an app password at: https://bsky.app/settings/app-passwords + +Thread fetching works without auth via the public API. +""" + +import os +import httpx + +BSKY_PUBLIC_API = "https://public.api.bsky.app" +BSKY_AUTH_API = "https://bsky.social" + + +async def _get_session() -> dict | None: + """Authenticate with Bluesky and return session tokens, or None if no creds.""" + handle = os.environ.get("BLUESKY_HANDLE") + app_password = os.environ.get("BLUESKY_APP_PASSWORD") + if not handle or not app_password: + return None + + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.post( + f"{BSKY_AUTH_API}/xrpc/com.atproto.server.createSession", + json={"identifier": handle, "password": app_password}, + ) + resp.raise_for_status() + return resp.json() + + +def _format_post(post_view: dict) -> dict: + """Extract relevant fields from an AT Protocol post view.""" + post = post_view.get("post", post_view) + record = post.get("record", {}) + author = post.get("author", {}) + return { + "text": record.get("text", ""), + "author_handle": author.get("handle", ""), + "author_display_name": author.get("displayName", ""), + "created_at": record.get("createdAt", ""), + "like_count": post.get("likeCount", 0), + "repost_count": post.get("repostCount", 0), + "reply_count": post.get("replyCount", 0), + "uri": post.get("uri", ""), + "cid": post.get("cid", ""), + "url": _uri_to_url(post.get("uri", ""), author.get("handle", "")), + } + + +def _uri_to_url(uri: str, handle: str) -> str: + """Convert an at:// URI to a bsky.app URL.""" + # at://did:plc:xxx/app.bsky.feed.post/rkey -> https://bsky.app/profile/handle/post/rkey + if not uri.startswith("at://"): + return "" + parts = uri.split("/") + if len(parts) >= 5: + rkey = parts[-1] + return f"https://bsky.app/profile/{handle}/post/{rkey}" + return "" + + +async def search_posts(query: str, limit: int = 25, sort: str = "top") -> list[dict]: + """Search Bluesky for posts matching a query. + + Requires BLUESKY_HANDLE and BLUESKY_APP_PASSWORD env vars. + + Args: + query: Search terms. + limit: Max results (capped at 100). + sort: "top" (most liked) or "latest" (chronological). + + Returns: + List of post dicts with: text, author_handle, author_display_name, + created_at, like_count, repost_count, reply_count, uri, url. + + Raises: + RuntimeError: If Bluesky credentials are not configured. + """ + session = await _get_session() + if not session: + raise RuntimeError( + "Bluesky search requires authentication. " + "Set BLUESKY_HANDLE and BLUESKY_APP_PASSWORD environment variables. " + "Create an app password at: https://bsky.app/settings/app-passwords" + ) + + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get( + f"{BSKY_AUTH_API}/xrpc/app.bsky.feed.searchPosts", + params={ + "q": query, + "limit": min(limit, 100), + "sort": sort, + }, + headers={"Authorization": f"Bearer {session['accessJwt']}"}, + ) + resp.raise_for_status() + data = resp.json() + + return [_format_post(p) for p in data.get("posts", [])] + + +async def get_thread(uri: str, depth: int = 6) -> dict: + """Fetch a Bluesky thread by AT URI or bsky.app URL. + + Args: + uri: Either an at:// URI or a https://bsky.app/profile/.../post/... URL. + depth: How many levels of replies to fetch (max 1000). + + Returns: + Dict with "post" (the root post) and "replies" (list of reply post dicts). + """ + # Convert bsky.app URL to AT URI if needed + if uri.startswith("https://bsky.app/"): + uri = await _resolve_url_to_uri(uri) + + headers = {} + session = await _get_session() + if session: + headers["Authorization"] = f"Bearer {session['accessJwt']}" + + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get( + f"{BSKY_PUBLIC_API}/xrpc/app.bsky.feed.getPostThread", + params={"uri": uri, "depth": min(depth, 1000)}, + headers=headers, + ) + resp.raise_for_status() + data = resp.json() + + thread = data.get("thread", {}) + root_post = _format_post(thread) if "post" in thread else {} + + replies = [] + for reply in thread.get("replies", []): + if "post" in reply: + replies.append(_format_post(reply)) + # Include nested replies one level deep + for nested in reply.get("replies", []): + if "post" in nested: + replies.append(_format_post(nested)) + + return {"post": root_post, "replies": replies} + + +async def _resolve_url_to_uri(url: str) -> str: + """Convert a bsky.app URL to an AT URI by resolving the handle.""" + # https://bsky.app/profile/handle.bsky.social/post/rkey + parts = url.rstrip("/").split("/") + if len(parts) < 6: + raise ValueError(f"Invalid Bluesky URL: {url}") + + handle = parts[4] # profile/{handle} + rkey = parts[6] # post/{rkey} + + # Resolve handle to DID + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get( + f"{BSKY_PUBLIC_API}/xrpc/com.atproto.identity.resolveHandle", + params={"handle": handle}, + ) + resp.raise_for_status() + did = resp.json()["did"] + + return f"at://{did}/app.bsky.feed.post/{rkey}" diff --git a/agentstuff/sentiment_agent/clients/hackernews.py b/agentstuff/sentiment_agent/clients/hackernews.py new file mode 100644 index 0000000..350107f --- /dev/null +++ b/agentstuff/sentiment_agent/clients/hackernews.py @@ -0,0 +1,78 @@ +"""Hacker News client using the Algolia HN Search API. + +No authentication required. Docs: https://hn.algolia.com/api +""" + +import httpx + +HN_API_BASE = "https://hn.algolia.com/api/v1" + + +async def search_stories(query: str, limit: int = 25) -> list[dict]: + """Search HN for stories matching a query. + + Returns a list of story dicts with: title, url, author, points, + num_comments, created_at, objectID, story_text. + """ + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get( + f"{HN_API_BASE}/search", + params={ + "query": query, + "tags": "story", + "hitsPerPage": min(limit, 50), + }, + ) + resp.raise_for_status() + data = resp.json() + + results = [] + for hit in data.get("hits", []): + results.append( + { + "title": hit.get("title", ""), + "url": hit.get("url", ""), + "author": hit.get("author", ""), + "points": hit.get("points", 0), + "num_comments": hit.get("num_comments", 0), + "created_at": hit.get("created_at", ""), + "object_id": hit.get("objectID", ""), + "story_text": hit.get("story_text") or "", + "hn_url": f"https://news.ycombinator.com/item?id={hit.get('objectID', '')}", + } + ) + return results + + +async def search_comments(query: str, limit: int = 25) -> list[dict]: + """Search HN for comments matching a query. + + Returns a list of comment dicts with: comment_text, author, points, + created_at, story_title, story_url. + """ + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get( + f"{HN_API_BASE}/search", + params={ + "query": query, + "tags": "comment", + "hitsPerPage": min(limit, 50), + }, + ) + resp.raise_for_status() + data = resp.json() + + results = [] + for hit in data.get("hits", []): + results.append( + { + "comment_text": hit.get("comment_text", ""), + "author": hit.get("author", ""), + "points": hit.get("points", 0), + "created_at": hit.get("created_at", ""), + "story_title": hit.get("story_title", ""), + "story_url": hit.get("story_url", ""), + "hn_url": f"https://news.ycombinator.com/item?id={hit.get('objectID', '')}", + } + ) + return results diff --git a/agentstuff/sentiment_agent/clients/reddit.py b/agentstuff/sentiment_agent/clients/reddit.py new file mode 100644 index 0000000..6a9d129 --- /dev/null +++ b/agentstuff/sentiment_agent/clients/reddit.py @@ -0,0 +1,117 @@ +"""Reddit client using the public JSON API. + +No authentication required for read-only search. Reddit requires a descriptive +User-Agent header — requests with generic UAs get 429'd. +""" + +import httpx + +REDDIT_BASE = "https://www.reddit.com" +USER_AGENT = "sentiment-agent/0.1.0 (research; sentiment analysis tool)" + + +async def search_posts( + query: str, + subreddit: str = "all", + sort: str = "relevance", + time_filter: str = "month", + limit: int = 25, +) -> list[dict]: + """Search Reddit for posts matching a query. + + Args: + query: Search terms. + subreddit: Subreddit to search within, or "all" for site-wide. + sort: One of "relevance", "hot", "top", "new", "comments". + time_filter: One of "hour", "day", "week", "month", "year", "all". + limit: Max results (capped at 100 by Reddit). + + Returns: + List of post dicts with: title, selftext, author, score, + num_comments, subreddit, url, permalink, created_utc. + """ + url = f"{REDDIT_BASE}/r/{subreddit}/search.json" + async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client: + resp = await client.get( + url, + params={ + "q": query, + "sort": sort, + "t": time_filter, + "limit": min(limit, 100), + "restrict_sr": "on" if subreddit != "all" else "off", + }, + headers={"User-Agent": USER_AGENT}, + ) + resp.raise_for_status() + data = resp.json() + + results = [] + for child in data.get("data", {}).get("children", []): + post = child.get("data", {}) + results.append( + { + "title": post.get("title", ""), + "selftext": (post.get("selftext") or "")[:2000], + "author": post.get("author", "[deleted]"), + "score": post.get("score", 0), + "upvote_ratio": post.get("upvote_ratio", 0), + "num_comments": post.get("num_comments", 0), + "subreddit": post.get("subreddit", ""), + "url": post.get("url", ""), + "permalink": f"https://reddit.com{post.get('permalink', '')}", + "created_utc": post.get("created_utc", 0), + "is_self": post.get("is_self", False), + } + ) + return results + + +async def get_post_comments( + permalink: str, + sort: str = "top", + limit: int = 25, +) -> list[dict]: + """Fetch top-level comments for a Reddit post. + + Args: + permalink: The post's permalink path (e.g., "/r/python/comments/abc123/title/"). + sort: Comment sort order: "top", "best", "new", "controversial". + limit: Max comments to return. + + Returns: + List of comment dicts with: body, author, score, created_utc. + """ + # Strip domain if full URL was passed + if permalink.startswith("https://"): + permalink = permalink.replace("https://reddit.com", "") + permalink = permalink.replace("https://www.reddit.com", "") + + url = f"{REDDIT_BASE}{permalink}.json" + async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client: + resp = await client.get( + url, + params={"sort": sort, "limit": limit}, + headers={"User-Agent": USER_AGENT}, + ) + resp.raise_for_status() + data = resp.json() + + # Reddit returns [post_listing, comments_listing] + if not isinstance(data, list) or len(data) < 2: + return [] + + results = [] + for child in data[1].get("data", {}).get("children", []): + if child.get("kind") != "t1": + continue + comment = child.get("data", {}) + results.append( + { + "body": (comment.get("body") or "")[:2000], + "author": comment.get("author", "[deleted]"), + "score": comment.get("score", 0), + "created_utc": comment.get("created_utc", 0), + } + ) + return results diff --git a/agentstuff/sentiment_agent/config.py b/agentstuff/sentiment_agent/config.py new file mode 100644 index 0000000..d12370d --- /dev/null +++ b/agentstuff/sentiment_agent/config.py @@ -0,0 +1,70 @@ +"""Configuration and safety limits for the sentiment agent. + +All guardrails are centralized here so they can be tuned from one place +or overridden via CLI flags / env vars. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class RateLimitConfig: + """Per-platform rate limiting.""" + + requests_per_minute: int = 10 + burst_size: int = 3 # max concurrent requests + cooldown_after_429: float = 30.0 # seconds to wait after a 429 + + +@dataclass(frozen=True) +class SafetyConfig: + """Top-level safety rails for the agent.""" + + # --- Agent-level limits --- + max_turns: int = 20 + max_budget_usd: float = 0.50 # hard cap on Claude API spend per run + max_total_api_calls: int = 50 # across ALL platforms combined + max_results_per_call: int = 50 # cap the `limit` param sent to any API + + # --- Per-platform rate limits --- + bluesky_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig( + requests_per_minute=10, burst_size=2, + )) + reddit_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig( + requests_per_minute=10, burst_size=2, + )) + hackernews_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig( + requests_per_minute=15, burst_size=3, # HN Algolia is more generous + )) + + # --- Content size limits --- + max_post_text_chars: int = 2000 # truncate individual posts beyond this + max_total_content_bytes: int = 500_000 # ~500KB total data gathered before agent stops + + # --- Timeout --- + api_timeout_seconds: float = 15.0 + + # --- Credibility thresholds --- + min_credibility_score: float = 0.3 # posts below this are flagged/excluded + flag_bot_threshold: float = 0.5 # posts between min and this are flagged but included + + @classmethod + def from_env(cls) -> SafetyConfig: + """Build config with env var overrides. + + Env vars: SENTIMENT_MAX_TURNS, SENTIMENT_MAX_BUDGET_USD, + SENTIMENT_MAX_API_CALLS, SENTIMENT_MIN_CREDIBILITY. + """ + kwargs: dict = {} + if v := os.environ.get("SENTIMENT_MAX_TURNS"): + kwargs["max_turns"] = int(v) + if v := os.environ.get("SENTIMENT_MAX_BUDGET_USD"): + kwargs["max_budget_usd"] = float(v) + if v := os.environ.get("SENTIMENT_MAX_API_CALLS"): + kwargs["max_total_api_calls"] = int(v) + if v := os.environ.get("SENTIMENT_MIN_CREDIBILITY"): + kwargs["min_credibility_score"] = float(v) + return cls(**kwargs) diff --git a/agentstuff/sentiment_agent/credibility.py b/agentstuff/sentiment_agent/credibility.py new file mode 100644 index 0000000..574ec1f --- /dev/null +++ b/agentstuff/sentiment_agent/credibility.py @@ -0,0 +1,398 @@ +"""Credibility scoring and bot/disinfo detection. + +Assigns a 0.0–1.0 credibility score to each post based on heuristic signals. +Posts below the configured threshold are excluded or flagged so they don't +pollute the sentiment analysis. + +Signals are platform-aware — each platform has different indicators of +inauthentic behavior. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone + + +@dataclass +class CredibilityResult: + """Credibility assessment for a single post.""" + + score: float # 0.0 (likely bot/disinfo) to 1.0 (likely authentic) + flags: list[str] = field(default_factory=list) # human-readable reasons + is_excluded: bool = False # below min_credibility_score + is_flagged: bool = False # between min and flag threshold + + @property + def label(self) -> str: + if self.score >= 0.7: + return "likely_authentic" + if self.score >= 0.5: + return "uncertain" + if self.score >= 0.3: + return "suspicious" + return "likely_inauthentic" + + +# --- Shared heuristics --- + +# Common bot patterns in text +_BOT_TEXT_PATTERNS = [ + # Crypto/scam spam + re.compile(r"(?i)(dm me|check my bio|link in bio|click here|free giveaway)"), + re.compile(r"(?i)(join my|subscribe to|follow me for|🔥.*🔥.*🔥)"), + # Astroturfing phrases + re.compile(r"(?i)(i (just )?(discovered|found|tried) this (amazing|incredible|awesome))"), + re.compile(r"(?i)(game.?changer|life.?changing|you won'?t believe)"), + # Excessive hashtags (5+) + re.compile(r"(#\w+\s*){5,}"), + # Walls of emojis (10+ consecutive) + re.compile(r"[\U0001F300-\U0001FAFF]{10,}"), + # Repetitive characters (spammy emphasis) + re.compile(r"(.)\1{9,}"), +] + +# Coordinated campaign indicators: identical or near-identical text +# This is checked at the batch level, not per-post + + +def _check_text_patterns(text: str) -> list[str]: + """Check text against common bot/spam patterns.""" + flags = [] + for pattern in _BOT_TEXT_PATTERNS: + if pattern.search(text): + flags.append(f"bot_text_pattern: {pattern.pattern[:60]}") + if len(text) < 15: + flags.append("very_short_text") + return flags + + +def _engagement_ratio_score( + likes: int, reposts: int, replies: int +) -> tuple[float, list[str]]: + """Score based on engagement ratios. + + Authentic posts tend to have a mix of likes, replies, and reposts. + Bot-amplified posts often have inflated likes with very few replies, + or massive repost counts with no discussion. + """ + flags = [] + total = likes + reposts + replies + + if total == 0: + return 0.5, ["no_engagement"] + + # High repost-to-reply ratio suggests amplification without discussion + if reposts > 0 and replies == 0 and reposts > 10: + flags.append(f"high_repost_no_replies: {reposts} reposts, 0 replies") + return 0.3, flags + + # Extremely high like count with zero replies is suspicious + if likes > 100 and replies == 0: + flags.append(f"high_likes_no_replies: {likes} likes, 0 replies") + return 0.4, flags + + # Normal engagement + return min(1.0, 0.5 + (replies / max(total, 1)) * 0.5), flags + + +# --- Platform-specific scoring --- + + +def score_bluesky_post(post: dict) -> CredibilityResult: + """Score a Bluesky post for credibility.""" + score = 1.0 + flags: list[str] = [] + + text = post.get("text", "") + handle = post.get("author_handle", "") + display_name = post.get("author_display_name", "") + likes = post.get("like_count", 0) + reposts = post.get("repost_count", 0) + replies = post.get("reply_count", 0) + + # Text pattern checks + text_flags = _check_text_patterns(text) + if text_flags: + score -= 0.15 * len(text_flags) + flags.extend(text_flags) + + # Handle heuristics + # Randomly generated handles (long hex/number strings) + if re.match(r"^[a-f0-9]{8,}\.", handle): + flags.append(f"random_handle: {handle}") + score -= 0.3 + + # No display name set + if not display_name or display_name == handle: + flags.append("no_display_name") + score -= 0.1 + + # Engagement ratio + eng_score, eng_flags = _engagement_ratio_score(likes, reposts, replies) + flags.extend(eng_flags) + score = score * 0.6 + eng_score * 0.4 + + return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags) + + +def score_reddit_post(post: dict) -> CredibilityResult: + """Score a Reddit post for credibility.""" + score = 1.0 + flags: list[str] = [] + + text = post.get("selftext", "") or post.get("title", "") + author = post.get("author", "") + upvote_ratio = post.get("upvote_ratio", 0.5) + post_score = post.get("score", 0) + num_comments = post.get("num_comments", 0) + + # Text patterns + text_flags = _check_text_patterns(text) + if text_flags: + score -= 0.15 * len(text_flags) + flags.extend(text_flags) + + # Deleted author + if author in ("[deleted]", "[removed]"): + flags.append("deleted_author") + score -= 0.2 + + # Suspicious username patterns (random alphanumeric + numbers) + if re.match(r"^[A-Za-z]+[-_]?\d{4,}$", author): + flags.append(f"auto_generated_username: {author}") + score -= 0.15 + + # Very controversial ratio (lots of up AND down votes) + if upvote_ratio < 0.4 and post_score > 0: + flags.append(f"highly_controversial: {upvote_ratio:.0%} upvote ratio") + score -= 0.1 + + # High score but zero comments = potential vote manipulation + if post_score > 100 and num_comments == 0: + flags.append(f"high_score_no_comments: {post_score} score, 0 comments") + score -= 0.2 + + # Low-effort cross-post spam: very short title, external link, no selftext + if ( + len(post.get("title", "")) < 20 + and not post.get("is_self", True) + and not post.get("selftext") + ): + flags.append("possible_link_spam") + score -= 0.1 + + return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags) + + +def score_reddit_comment(comment: dict) -> CredibilityResult: + """Score a Reddit comment for credibility.""" + score = 1.0 + flags: list[str] = [] + + body = comment.get("body", "") + author = comment.get("author", "") + comment_score = comment.get("score", 0) + + text_flags = _check_text_patterns(body) + if text_flags: + score -= 0.15 * len(text_flags) + flags.extend(text_flags) + + if author in ("[deleted]", "[removed]"): + flags.append("deleted_author") + score -= 0.2 + + if re.match(r"^[A-Za-z]+[-_]?\d{4,}$", author): + flags.append(f"auto_generated_username: {author}") + score -= 0.15 + + # Heavily downvoted + if comment_score < -5: + flags.append(f"heavily_downvoted: {comment_score}") + score -= 0.15 + + return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags) + + +def score_hackernews_post(post: dict) -> CredibilityResult: + """Score a HN story for credibility. + + HN is generally higher-signal than social media, but we still check + for low-effort submissions and spammy patterns. + """ + score = 1.0 + flags: list[str] = [] + + title = post.get("title", "") + text = post.get("story_text", "") or title + points = post.get("points", 0) + num_comments = post.get("num_comments", 0) + + text_flags = _check_text_patterns(text) + if text_flags: + score -= 0.1 * len(text_flags) + flags.extend(text_flags) + + # Zero points = the community didn't find it valuable + if points == 0: + flags.append("zero_points") + score -= 0.1 + + # HN is generally more credible, start with a bonus + score = min(1.0, score + 0.1) + + return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags) + + +def score_hackernews_comment(comment: dict) -> CredibilityResult: + """Score a HN comment for credibility.""" + score = 1.0 + flags: list[str] = [] + + text = comment.get("comment_text", "") + + text_flags = _check_text_patterns(text) + if text_flags: + score -= 0.1 * len(text_flags) + flags.extend(text_flags) + + # HN comments are generally higher quality + score = min(1.0, score + 0.1) + + return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags) + + +# --- Batch-level coordination detection --- + + +def detect_coordination(posts: list[dict], text_key: str = "text") -> list[str]: + """Detect coordinated inauthentic behavior across a batch of posts. + + Looks for: + - Duplicate or near-duplicate text (copy-paste campaigns) + - Burst posting (many posts in a very short window) + - Same talking points with minor variations + + Returns a list of warning strings. + """ + warnings: list[str] = [] + texts = [p.get(text_key, "") for p in posts if p.get(text_key)] + + if not texts: + return warnings + + # Exact duplicates + seen: dict[str, int] = {} + for t in texts: + normalized = t.strip().lower() + seen[normalized] = seen.get(normalized, 0) + 1 + + duplicates = {text: count for text, count in seen.items() if count > 1} + if duplicates: + total_dupes = sum(duplicates.values()) + warnings.append( + f"COORDINATION WARNING: {len(duplicates)} duplicate texts found " + f"({total_dupes} total copies). Possible copy-paste campaign." + ) + + # Near-duplicates: check if many posts share a long common substring + # (simplified: check if >30% of posts start with the same 50+ chars) + if len(texts) >= 5: + prefixes: dict[str, int] = {} + for t in texts: + prefix = t.strip().lower()[:80] + if len(prefix) >= 50: + prefixes[prefix] = prefixes.get(prefix, 0) + 1 + + for prefix, count in prefixes.items(): + if count >= len(texts) * 0.3: + warnings.append( + f"COORDINATION WARNING: {count}/{len(texts)} posts share " + f"a common prefix ({prefix[:50]}...). Possible template campaign." + ) + + # Burst detection: if timestamps are available + timestamps = [] + for p in posts: + created = p.get("created_at") or p.get("created_utc") + if isinstance(created, str): + try: + timestamps.append(datetime.fromisoformat(created.replace("Z", "+00:00"))) + except (ValueError, TypeError): + pass + elif isinstance(created, (int, float)): + timestamps.append(datetime.fromtimestamp(created, tz=timezone.utc)) + + if len(timestamps) >= 5: + timestamps.sort() + # Check if >50% of posts landed within a 5-minute window + window_seconds = 300 + for i in range(len(timestamps) - 2): + window_end = timestamps[i] + __import__("datetime").timedelta(seconds=window_seconds) + in_window = sum(1 for t in timestamps if timestamps[i] <= t <= window_end) + if in_window >= len(timestamps) * 0.5: + warnings.append( + f"COORDINATION WARNING: {in_window}/{len(timestamps)} posts " + f"appeared within a 5-minute window. Possible coordinated posting." + ) + break + + return warnings + + +def filter_and_annotate( + posts: list[dict], + scorer, + min_score: float = 0.3, + flag_threshold: float = 0.5, +) -> tuple[list[dict], dict]: + """Score all posts, filter out low-credibility ones, and annotate the rest. + + Args: + posts: List of post dicts from any platform. + scorer: A scoring function (e.g., score_reddit_post). + min_score: Posts below this are excluded. + flag_threshold: Posts between min_score and this are flagged. + + Returns: + Tuple of (filtered_posts, stats_dict). + Each post in filtered_posts gets a "_credibility" key added. + """ + filtered = [] + stats = { + "total": len(posts), + "excluded": 0, + "flagged": 0, + "authentic": 0, + "excluded_reasons": [], + } + + for post in posts: + result = scorer(post) + result.is_excluded = result.score < min_score + result.is_flagged = min_score <= result.score < flag_threshold + + if result.is_excluded: + stats["excluded"] += 1 + stats["excluded_reasons"].append( + {"score": round(result.score, 2), "flags": result.flags} + ) + continue + + post["_credibility"] = { + "score": round(result.score, 2), + "label": result.label, + "flags": result.flags, + "is_flagged": result.is_flagged, + } + + if result.is_flagged: + stats["flagged"] += 1 + else: + stats["authentic"] += 1 + + filtered.append(post) + + return filtered, stats diff --git a/agentstuff/sentiment_agent/main.py b/agentstuff/sentiment_agent/main.py new file mode 100644 index 0000000..2e4a178 --- /dev/null +++ b/agentstuff/sentiment_agent/main.py @@ -0,0 +1,66 @@ +"""CLI entry point for the sentiment analysis agent.""" + +import argparse +import anyio + +from sentiment_agent.agent import run_sentiment_analysis +from sentiment_agent.config import SafetyConfig + + +async def async_main(args: argparse.Namespace) -> None: + config = SafetyConfig( + max_turns=args.max_turns, + max_budget_usd=args.max_budget, + max_total_api_calls=args.max_api_calls, + min_credibility_score=args.min_credibility, + flag_bot_threshold=args.flag_threshold, + ) + + result = await run_sentiment_analysis( + topic=args.topic, + sources=args.sources, + config=config, + ) + print("\n" + "=" * 60) + print("SENTIMENT ANALYSIS REPORT") + print("=" * 60) + print(result) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run sentiment analysis on a topic with bot/disinfo detection", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("topic", help="The topic to analyze sentiment for") + parser.add_argument( + "--sources", nargs="*", help="Specific URLs or sources to also analyze" + ) + + safety = parser.add_argument_group("safety limits") + safety.add_argument( + "--max-turns", type=int, default=20, help="Max agent turns" + ) + safety.add_argument( + "--max-budget", type=float, default=0.50, help="Max Claude API spend (USD)" + ) + safety.add_argument( + "--max-api-calls", type=int, default=50, help="Max total API calls across all platforms" + ) + + credibility = parser.add_argument_group("credibility filtering") + credibility.add_argument( + "--min-credibility", type=float, default=0.3, + help="Posts below this score are excluded (0.0-1.0)", + ) + credibility.add_argument( + "--flag-threshold", type=float, default=0.5, + help="Posts between min and this are flagged but included (0.0-1.0)", + ) + + args = parser.parse_args() + anyio.run(async_main, args) + + +if __name__ == "__main__": + main() diff --git a/agentstuff/sentiment_agent/ratelimit.py b/agentstuff/sentiment_agent/ratelimit.py new file mode 100644 index 0000000..17b4d24 --- /dev/null +++ b/agentstuff/sentiment_agent/ratelimit.py @@ -0,0 +1,169 @@ +"""Rate limiter and API call budget tracker. + +Enforces per-platform rate limits and a global call budget so the agent +can't hammer APIs or run up unbounded costs. +""" + +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass, field + +from sentiment_agent.config import RateLimitConfig + + +class BudgetExhaustedError(Exception): + """Raised when the global API call budget is spent.""" + + +class RateLimitExceededError(Exception): + """Raised when a platform's rate limit is hit and cooldown hasn't elapsed.""" + + +@dataclass +class _PlatformState: + """Tracks call timestamps and active request count for one platform.""" + + config: RateLimitConfig + call_timestamps: list[float] = field(default_factory=list) + active_requests: int = 0 + last_429_at: float = 0.0 + + +class RateLimiter: + """Manages rate limiting across all platforms + a global call budget. + + Usage: + limiter = RateLimiter(max_total_calls=50) + limiter.register_platform("reddit", RateLimitConfig(...)) + + async with limiter.acquire("reddit"): + await do_reddit_call() + """ + + def __init__(self, max_total_calls: int = 50): + self._max_total = max_total_calls + self._total_calls = 0 + self._platforms: dict[str, _PlatformState] = {} + self._lock = asyncio.Lock() + + @property + def total_calls(self) -> int: + return self._total_calls + + @property + def remaining_calls(self) -> int: + return max(0, self._max_total - self._total_calls) + + def register_platform(self, name: str, config: RateLimitConfig) -> None: + self._platforms[name] = _PlatformState(config=config) + + def acquire(self, platform: str) -> _AcquireContext: + """Context manager that enforces rate limits before allowing a call.""" + return _AcquireContext(self, platform) + + async def _acquire(self, platform: str) -> None: + async with self._lock: + if self._total_calls >= self._max_total: + raise BudgetExhaustedError( + f"Global API call budget exhausted ({self._max_total} calls). " + "Increase max_total_api_calls in SafetyConfig to allow more." + ) + + state = self._platforms.get(platform) + if not state: + raise ValueError(f"Platform '{platform}' not registered with rate limiter") + + now = time.monotonic() + + # Check 429 cooldown + if state.last_429_at: + elapsed = now - state.last_429_at + if elapsed < state.config.cooldown_after_429: + remaining = state.config.cooldown_after_429 - elapsed + raise RateLimitExceededError( + f"Platform '{platform}' is in cooldown after 429. " + f"Try again in {remaining:.0f}s." + ) + state.last_429_at = 0.0 + + # Check burst limit + if state.active_requests >= state.config.burst_size: + raise RateLimitExceededError( + f"Platform '{platform}' burst limit reached " + f"({state.config.burst_size} concurrent). Wait for a request to finish." + ) + + # Check RPM: discard timestamps older than 60s, then check count + cutoff = now - 60.0 + state.call_timestamps = [t for t in state.call_timestamps if t > cutoff] + + if len(state.call_timestamps) >= state.config.requests_per_minute: + oldest = state.call_timestamps[0] + wait_time = 60.0 - (now - oldest) + raise RateLimitExceededError( + f"Platform '{platform}' rate limit: {state.config.requests_per_minute}/min. " + f"Try again in {wait_time:.0f}s." + ) + + # All clear — record the call + state.call_timestamps.append(now) + state.active_requests += 1 + self._total_calls += 1 + + async def _release(self, platform: str) -> None: + async with self._lock: + state = self._platforms.get(platform) + if state: + state.active_requests = max(0, state.active_requests - 1) + + def record_429(self, platform: str) -> None: + """Call this when an API returns 429 to trigger cooldown.""" + state = self._platforms.get(platform) + if state: + state.last_429_at = time.monotonic() + + def get_stats(self) -> dict: + """Return current usage stats for logging/reporting.""" + stats: dict = { + "total_calls": self._total_calls, + "remaining_calls": self.remaining_calls, + "platforms": {}, + } + for name, state in self._platforms.items(): + now = time.monotonic() + cutoff = now - 60.0 + recent = [t for t in state.call_timestamps if t > cutoff] + stats["platforms"][name] = { + "calls_last_60s": len(recent), + "active_requests": state.active_requests, + "rpm_limit": state.config.requests_per_minute, + "in_cooldown": bool( + state.last_429_at + and (now - state.last_429_at) < state.config.cooldown_after_429 + ), + } + return stats + + +class _AcquireContext: + """Async context manager for rate-limited API calls.""" + + def __init__(self, limiter: RateLimiter, platform: str): + self._limiter = limiter + self._platform = platform + + async def __aenter__(self) -> None: + await self._limiter._acquire(self._platform) + + async def __aexit__(self, *exc_info) -> None: + # Check if the call got a 429 + if exc_info[0] is not None: + import httpx + + exc = exc_info[1] + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429: + self._limiter.record_429(self._platform) + + await self._limiter._release(self._platform) diff --git a/agentstuff/sentiment_agent/tools.py b/agentstuff/sentiment_agent/tools.py new file mode 100644 index 0000000..4706f00 --- /dev/null +++ b/agentstuff/sentiment_agent/tools.py @@ -0,0 +1,352 @@ +"""Custom MCP tools for social media and forum data gathering. + +Each tool wraps an API client, enforces rate limits, runs credibility +scoring, and returns MCP-formatted results with bot/disinfo annotations. +""" + +from __future__ import annotations + +import json +import traceback + +from claude_agent_sdk import tool, create_sdk_mcp_server + +from sentiment_agent.clients import bluesky, reddit, hackernews +from sentiment_agent.config import SafetyConfig +from sentiment_agent.credibility import ( + detect_coordination, + filter_and_annotate, + score_bluesky_post, + score_hackernews_comment, + score_hackernews_post, + score_reddit_comment, + score_reddit_post, +) +from sentiment_agent.ratelimit import BudgetExhaustedError, RateLimiter + +# Module-level state — initialized by create_social_tools_server() +_limiter: RateLimiter | None = None +_config: SafetyConfig | None = None + + +def _get_limiter() -> RateLimiter: + if _limiter is None: + raise RuntimeError("Tools not initialized — call create_social_tools_server() first") + return _limiter + + +def _get_config() -> SafetyConfig: + if _config is None: + return SafetyConfig() + return _config + + +def _text_result(text: str) -> dict: + return {"content": [{"type": "text", "text": text}]} + + +def _error_result(error: str) -> dict: + return {"content": [{"type": "text", "text": f"Error: {error}"}], "isError": True} + + +def _clamp_limit(requested: int) -> int: + """Enforce max results per call.""" + return min(requested, _get_config().max_results_per_call) + + +def _format_with_stats( + posts: list[dict], + stats: dict, + coordination_warnings: list[str], + platform: str, +) -> str: + """Format results with credibility stats prepended.""" + header_parts = [ + f"Platform: {platform}", + f"Results: {stats['authentic']} authentic, {stats['flagged']} flagged, " + f"{stats['excluded']} excluded (of {stats['total']} total)", + ] + if coordination_warnings: + header_parts.append("--- COORDINATION ALERTS ---") + header_parts.extend(coordination_warnings) + header_parts.append("---") + + limiter = _get_limiter() + header_parts.append(f"API budget remaining: {limiter.remaining_calls} calls") + + header = "\n".join(header_parts) + body = json.dumps(posts, indent=2, default=str) + return f"{header}\n\n{body}" + + +# --- Bluesky tools --- + + +@tool( + "search_bluesky", + "Search Bluesky for posts about a topic. Returns posts with text, author, " + "engagement metrics, credibility scores, and bot/disinfo flags. " + "Requires BLUESKY_HANDLE and BLUESKY_APP_PASSWORD env vars.", + {"query": str, "limit": int, "sort": str}, +) +async def search_bluesky(args: dict) -> dict: + try: + limiter = _get_limiter() + config = _get_config() + + async with limiter.acquire("bluesky"): + posts = await bluesky.search_posts( + query=args["query"], + limit=_clamp_limit(args.get("limit", 25)), + sort=args.get("sort", "top"), + ) + + if not posts: + return _text_result(f"No Bluesky posts found for: {args['query']}") + + coordination = detect_coordination(posts, text_key="text") + filtered, stats = filter_and_annotate( + posts, score_bluesky_post, + min_score=config.min_credibility_score, + flag_threshold=config.flag_bot_threshold, + ) + return _text_result(_format_with_stats(filtered, stats, coordination, "Bluesky")) + except BudgetExhaustedError as e: + return _error_result(str(e)) + except Exception as e: + return _error_result(f"Bluesky search failed: {e}\n{traceback.format_exc()}") + + +@tool( + "get_bluesky_thread", + "Fetch a Bluesky thread/post and its replies with credibility scoring. " + "Accepts an at:// URI or https://bsky.app/... URL.", + {"uri": str, "depth": int}, +) +async def get_bluesky_thread(args: dict) -> dict: + try: + limiter = _get_limiter() + config = _get_config() + + async with limiter.acquire("bluesky"): + thread = await bluesky.get_thread( + uri=args["uri"], + depth=args.get("depth", 6), + ) + + # Score replies + if thread.get("replies"): + coordination = detect_coordination(thread["replies"], text_key="text") + filtered_replies, stats = filter_and_annotate( + thread["replies"], score_bluesky_post, + min_score=config.min_credibility_score, + flag_threshold=config.flag_bot_threshold, + ) + thread["replies"] = filtered_replies + thread["_reply_credibility_stats"] = stats + thread["_coordination_warnings"] = coordination + + # Score root post + if thread.get("post"): + result = score_bluesky_post(thread["post"]) + thread["post"]["_credibility"] = { + "score": round(result.score, 2), + "label": result.label, + "flags": result.flags, + } + + return _text_result(json.dumps(thread, indent=2, default=str)) + except BudgetExhaustedError as e: + return _error_result(str(e)) + except Exception as e: + return _error_result(f"Bluesky thread fetch failed: {e}\n{traceback.format_exc()}") + + +# --- Reddit tools --- + + +@tool( + "search_reddit", + "Search Reddit for posts about a topic. Returns posts with credibility scores " + "and bot/disinfo flags. Posts below the credibility threshold are auto-excluded. " + "Use subreddit='all' for site-wide or specify a subreddit name.", + {"query": str, "subreddit": str, "sort": str, "time_filter": str, "limit": int}, +) +async def search_reddit_tool(args: dict) -> dict: + try: + limiter = _get_limiter() + config = _get_config() + + async with limiter.acquire("reddit"): + posts = await reddit.search_posts( + query=args["query"], + subreddit=args.get("subreddit", "all"), + sort=args.get("sort", "relevance"), + time_filter=args.get("time_filter", "month"), + limit=_clamp_limit(args.get("limit", 25)), + ) + + if not posts: + return _text_result(f"No Reddit posts found for: {args['query']}") + + coordination = detect_coordination(posts, text_key="title") + filtered, stats = filter_and_annotate( + posts, score_reddit_post, + min_score=config.min_credibility_score, + flag_threshold=config.flag_bot_threshold, + ) + return _text_result(_format_with_stats(filtered, stats, coordination, "Reddit")) + except BudgetExhaustedError as e: + return _error_result(str(e)) + except Exception as e: + return _error_result(f"Reddit search failed: {e}\n{traceback.format_exc()}") + + +@tool( + "get_reddit_comments", + "Fetch comments for a Reddit post with credibility scoring. " + "Pass the permalink path or full URL.", + {"permalink": str, "sort": str, "limit": int}, +) +async def get_reddit_comments(args: dict) -> dict: + try: + limiter = _get_limiter() + config = _get_config() + + async with limiter.acquire("reddit"): + comments = await reddit.get_post_comments( + permalink=args["permalink"], + sort=args.get("sort", "top"), + limit=_clamp_limit(args.get("limit", 25)), + ) + + if not comments: + return _text_result("No comments found for this post.") + + coordination = detect_coordination(comments, text_key="body") + filtered, stats = filter_and_annotate( + comments, score_reddit_comment, + min_score=config.min_credibility_score, + flag_threshold=config.flag_bot_threshold, + ) + return _text_result(_format_with_stats(filtered, stats, coordination, "Reddit Comments")) + except BudgetExhaustedError as e: + return _error_result(str(e)) + except Exception as e: + return _error_result(f"Reddit comments fetch failed: {e}\n{traceback.format_exc()}") + + +# --- Hacker News tools --- + + +@tool( + "search_hackernews", + "Search Hacker News for stories with credibility scoring. " + "No authentication required.", + {"query": str, "limit": int}, +) +async def search_hackernews_tool(args: dict) -> dict: + try: + limiter = _get_limiter() + config = _get_config() + + async with limiter.acquire("hackernews"): + stories = await hackernews.search_stories( + query=args["query"], + limit=_clamp_limit(args.get("limit", 25)), + ) + + if not stories: + return _text_result(f"No HN stories found for: {args['query']}") + + coordination = detect_coordination(stories, text_key="title") + filtered, stats = filter_and_annotate( + stories, score_hackernews_post, + min_score=config.min_credibility_score, + flag_threshold=config.flag_bot_threshold, + ) + return _text_result(_format_with_stats(filtered, stats, coordination, "Hacker News")) + except BudgetExhaustedError as e: + return _error_result(str(e)) + except Exception as e: + return _error_result(f"HN search failed: {e}\n{traceback.format_exc()}") + + +@tool( + "search_hackernews_comments", + "Search Hacker News comments for opinions and discussions with credibility scoring.", + {"query": str, "limit": int}, +) +async def search_hackernews_comments(args: dict) -> dict: + try: + limiter = _get_limiter() + config = _get_config() + + async with limiter.acquire("hackernews"): + comments = await hackernews.search_comments( + query=args["query"], + limit=_clamp_limit(args.get("limit", 25)), + ) + + if not comments: + return _text_result(f"No HN comments found for: {args['query']}") + + coordination = detect_coordination(comments, text_key="comment_text") + filtered, stats = filter_and_annotate( + comments, score_hackernews_comment, + min_score=config.min_credibility_score, + flag_threshold=config.flag_bot_threshold, + ) + return _text_result( + _format_with_stats(filtered, stats, coordination, "HN Comments") + ) + except BudgetExhaustedError as e: + return _error_result(str(e)) + except Exception as e: + return _error_result(f"HN comment search failed: {e}\n{traceback.format_exc()}") + + +# --- Budget status tool --- + + +@tool( + "get_api_budget_status", + "Check remaining API call budget, rate limit status, and per-platform stats. " + "Use this before making more API calls to avoid hitting limits.", + {}, +) +async def get_api_budget_status(args: dict) -> dict: + limiter = _get_limiter() + stats = limiter.get_stats() + return _text_result(json.dumps(stats, indent=2, default=str)) + + +# --- Server factory --- + + +def create_social_tools_server(config: SafetyConfig | None = None): + """Create an MCP server with all social media/forum tools. + + Initializes rate limiting and credibility thresholds from config. + """ + global _limiter, _config + + _config = config or SafetyConfig.from_env() + + _limiter = RateLimiter(max_total_calls=_config.max_total_api_calls) + _limiter.register_platform("bluesky", _config.bluesky_rate) + _limiter.register_platform("reddit", _config.reddit_rate) + _limiter.register_platform("hackernews", _config.hackernews_rate) + + return create_sdk_mcp_server( + "social-tools", + tools=[ + search_bluesky, + get_bluesky_thread, + search_reddit_tool, + get_reddit_comments, + search_hackernews_tool, + search_hackernews_comments, + get_api_budget_status, + ], + ) diff --git a/frontends/web/app.py b/frontends/web/app.py index 81fe8e2..e9bab4a 100644 --- a/frontends/web/app.py +++ b/frontends/web/app.py @@ -1119,7 +1119,7 @@ def _run_encode_job(job_id: str, encode_params: dict) -> None: filename = encode_result.filename if not filename: - filename = generate_filename("stego", output_ext) + filename = generate_filename(prefix="stego", extension=output_ext.lstrip(".")) elif embed_mode == "dct" and dct_output_format == "jpeg" and filename.endswith(".png"): filename = filename[:-4] + ".jpg" @@ -1210,7 +1210,7 @@ def _run_encode_audio_job(job_id: str, encode_params: dict) -> None: ) return - filename = generate_filename("stego_audio", ".wav") + filename = generate_filename(prefix="stego_audio", extension="wav") file_id = secrets.token_urlsafe(16) temp_storage.save_temp_file( file_id, @@ -1273,6 +1273,9 @@ def encode_page(): if carrier_type == "audio": # ========== AUDIO ENCODE PATH (v4.3.0) ========== + # Audio carrier uses a separate form field to avoid name collision + carrier = request.files.get("audio_carrier") or carrier + if not HAS_AUDIO_SUPPORT: return _error_response( "Audio steganography is not available. Install audio dependencies." @@ -1439,7 +1442,7 @@ def encode_page(): error_msg = encode_result.error or "Audio encoding failed" return _error_response(error_msg) - filename = generate_filename("stego_audio", ".wav") + filename = generate_filename(prefix="stego_audio", extension="wav") file_id = secrets.token_urlsafe(16) cleanup_temp_files() temp_storage.save_temp_file( @@ -1676,7 +1679,7 @@ def encode_page(): # Use filename from result or generate one filename = encode_result.filename if not filename: - filename = generate_filename("stego", output_ext) + filename = generate_filename(prefix="stego", extension=output_ext.lstrip(".")) elif embed_mode == "dct" and dct_output_format == "jpeg" and filename.endswith(".png"): filename = filename[:-4] + ".jpg" @@ -2029,6 +2032,9 @@ def decode_page(): if carrier_type == "audio": # ========== AUDIO DECODE PATH (v4.3.0) ========== + # Audio stego uses a separate form field to avoid name collision + stego_image = request.files.get("stego_audio") or stego_image + if not HAS_AUDIO_SUPPORT: flash("Audio steganography is not available.", "error") return render_template("decode.html", has_qrcode_read=HAS_QRCODE_READ) diff --git a/frontends/web/static/js/stegasoo.js b/frontends/web/static/js/stegasoo.js index 75eff42..1987e3f 100644 --- a/frontends/web/static/js/stegasoo.js +++ b/frontends/web/static/js/stegasoo.js @@ -974,13 +974,13 @@ const Stegasoo = { body: formData, }); + const result = await response.json().catch(() => null); + if (!response.ok) { - throw new Error('Failed to start encode'); + throw new Error((result && result.error) || 'Failed to start encode'); } - const result = await response.json(); - - if (result.error) { + if (result && result.error) { throw new Error(result.error); } diff --git a/frontends/web/templates/decode.html b/frontends/web/templates/decode.html index 2d0eda4..c604ed8 100644 --- a/frontends/web/templates/decode.html +++ b/frontends/web/templates/decode.html @@ -280,7 +280,7 @@ Stego Audio
- +
Drop audio or click diff --git a/frontends/web/templates/encode.html b/frontends/web/templates/encode.html index f8ea03e..e8957cc 100644 --- a/frontends/web/templates/encode.html +++ b/frontends/web/templates/encode.html @@ -240,6 +240,11 @@
+ +
+ +
+
@@ -541,6 +546,82 @@ const audioModeGroup = document.getElementById('audioModeGroup'); const capacityPanel = document.getElementById('capacityPanel'); const audioCapacityPanel = document.getElementById('audioCapacityPanel'); +// Capacity tracking for client-side payload size validation +let capacityBytes = { dct: 0, lsb: 0, audio_lsb: 0, audio_spread: 0 }; + +function checkCapacity() { + const warning = document.getElementById('capacityWarning'); + const warningText = document.getElementById('capacityWarningText'); + const encodeBtn = document.getElementById('encodeBtn'); + if (!warning || !warningText || !encodeBtn) return; + + // Determine payload size + const isText = document.getElementById('payloadText')?.checked; + let payloadSize = 0; + if (isText) { + const msg = document.getElementById('messageInput')?.value || ''; + if (msg) payloadSize = new Blob([msg]).size; + } else { + const file = document.getElementById('payloadFileInput')?.files[0]; + if (file) payloadSize = file.size; + } + + // Get active mode + const mode = document.querySelector('input[name="embed_mode"]:checked')?.value || 'lsb'; + const cap = capacityBytes[mode] || 0; + + // Update char percent to use real capacity + if (isText) { + const charPercent = document.getElementById('charPercent'); + if (charPercent) { + const effectiveCap = cap > 0 ? cap : 250000; + charPercent.textContent = Math.round((payloadSize / effectiveCap) * 100) + '%'; + } + } + + // Reset badge colors + const badgeMap = { + dct: 'dctCapacityBadge', + lsb: 'lsbCapacityBadge', + audio_lsb: 'lsbAudioCapacityBadge', + audio_spread: 'spreadCapacityBadge' + }; + + // Restore default badge colors + const dctBadge = document.getElementById('dctCapacityBadge'); + const lsbBadge = document.getElementById('lsbCapacityBadge'); + const audioLsbBadge = document.getElementById('lsbAudioCapacityBadge'); + const spreadBadge = document.getElementById('spreadCapacityBadge'); + if (dctBadge) { dctBadge.classList.remove('bg-danger'); dctBadge.classList.add('bg-warning'); } + if (lsbBadge) { lsbBadge.classList.remove('bg-danger'); lsbBadge.classList.add('bg-primary'); } + if (audioLsbBadge) { audioLsbBadge.classList.remove('bg-danger'); audioLsbBadge.classList.add('bg-primary'); } + if (spreadBadge) { spreadBadge.classList.remove('bg-danger'); spreadBadge.classList.add('bg-warning'); } + + // No carrier or no payload — clear warning + if (cap === 0 || payloadSize === 0) { + warning.classList.add('d-none'); + encodeBtn.disabled = false; + return; + } + + if (payloadSize > cap) { + // Exceeds capacity — show warning, turn badge red, disable button + const activeBadge = document.getElementById(badgeMap[mode]); + if (activeBadge) { + activeBadge.classList.remove('bg-primary', 'bg-warning'); + activeBadge.classList.add('bg-danger'); + } + const needed = (payloadSize / 1024).toFixed(1); + const available = (cap / 1024).toFixed(1); + warningText.textContent = `Payload too large: ${needed} KB needed, only ${available} KB capacity in ${mode.replace('_', ' ').toUpperCase()} mode`; + warning.classList.remove('d-none'); + encodeBtn.disabled = true; + } else { + warning.classList.add('d-none'); + encodeBtn.disabled = false; + } +} + carrierTypeRadios.forEach(radio => { radio.addEventListener('change', function() { const isAudio = this.value === 'audio'; @@ -560,9 +641,17 @@ carrierTypeRadios.forEach(radio => { if (imageModeGroup) imageModeGroup.classList.toggle('d-none', isAudio); if (audioModeGroup) audioModeGroup.classList.toggle('d-none', !isAudio); - // Toggle capacity panels + // Toggle capacity panels and reset capacity values if (capacityPanel) capacityPanel.classList.add('d-none'); if (audioCapacityPanel) audioCapacityPanel.classList.add('d-none'); + if (isAudio) { + capacityBytes.dct = 0; + capacityBytes.lsb = 0; + } else { + capacityBytes.audio_lsb = 0; + capacityBytes.audio_spread = 0; + } + checkCapacity(); // Select default mode for the active type and update hint if (isAudio) { @@ -621,7 +710,10 @@ audioCarrierInput?.addEventListener('change', function() { document.getElementById('audioInfo').textContent = info; document.getElementById('lsbAudioCapacityBadge').textContent = `LSB: ${(data.lsb_capacity / 1024).toFixed(1)} KB`; document.getElementById('spreadCapacityBadge').textContent = `Spread: ${(data.spread_capacity / 1024).toFixed(1)} KB`; + capacityBytes.audio_lsb = data.lsb_capacity; + capacityBytes.audio_spread = data.spread_capacity; document.getElementById('audioCapacityPanel')?.classList.remove('d-none'); + checkCapacity(); if (data.duration) { document.getElementById('audioCarrierDuration').textContent = data.duration + 's duration'; } @@ -763,6 +855,7 @@ function updatePayloadSection() { payloadFileInput.setAttribute('required', ''); } updatePayloadSummary(); + checkCapacity(); } payloadTextRadio?.addEventListener('change', updatePayloadSection); @@ -786,6 +879,7 @@ payloadFileInput?.addEventListener('change', function() { } else { fileInfo?.classList.add('d-none'); } + checkCapacity(); }); // ============================================================================ @@ -795,7 +889,7 @@ payloadFileInput?.addEventListener('change', function() { messageInput?.addEventListener('input', function() { const count = this.value.length; document.getElementById('charCount').textContent = count.toLocaleString(); - document.getElementById('charPercent').textContent = Math.round((count / 250000) * 100) + '%'; + checkCapacity(); }); // ============================================================================ @@ -814,7 +908,10 @@ carrierInput?.addEventListener('change', function() { document.getElementById('carrierDimensions').textContent = `${data.width} x ${data.height}`; document.getElementById('lsbCapacityBadge').textContent = `LSB: ${data.lsb.capacity_kb} KB`; document.getElementById('dctCapacityBadge').textContent = `DCT: ${data.dct.capacity_kb} KB`; + capacityBytes.lsb = Math.round(data.lsb.capacity_kb * 1024); + capacityBytes.dct = Math.round(data.dct.capacity_kb * 1024); document.getElementById('capacityPanel')?.classList.remove('d-none'); + checkCapacity(); }).catch(() => {}); } }); @@ -859,7 +956,7 @@ function updateOutputOptions(mode) { } modeRadios.forEach(radio => { - radio.addEventListener('change', () => updateOutputOptions(radio.value)); + radio.addEventListener('change', () => { updateOutputOptions(radio.value); checkCapacity(); }); }); // Initialize output options based on initial mode diff --git a/src/stegasoo/audio_steganography.py b/src/stegasoo/audio_steganography.py index cf5f870..6397ff0 100644 --- a/src/stegasoo/audio_steganography.py +++ b/src/stegasoo/audio_steganography.py @@ -264,18 +264,26 @@ def embed_in_audio_lsb( debug.validate(len(sample_key) == 32, f"Sample key must be 32 bytes, got {len(sample_key)}") try: - # 1. Read carrier audio - samples, samplerate = sf.read(io.BytesIO(carrier_audio), dtype="int16", always_2d=True) - # samples shape: (num_frames, channels) - original_shape = samples.shape + # 1. Read carrier audio as float64 (handles all subtypes correctly) + buf = io.BytesIO(carrier_audio) + float_samples, samplerate = sf.read(buf, dtype="float64", always_2d=True) + original_shape = float_samples.shape channels = original_shape[1] duration = original_shape[0] / samplerate + # Detect original subtype for output + buf.seek(0) + carrier_info = sf.info(buf) + output_subtype = carrier_info.subtype or "PCM_16" + debug.print( f"Carrier audio: {samplerate} Hz, {channels} ch, " - f"{original_shape[0]} frames, {duration:.2f}s" + f"{original_shape[0]} frames, {duration:.2f}s, subtype={output_subtype}" ) + # Convert float64 → int16 for LSB manipulation (32768 matches libsndfile normalization) + samples = np.clip(float_samples * 32768.0, -32768, 32767).astype(np.int16) + # Flatten to 1D for embedding flat_samples = samples.flatten().copy() num_samples = len(flat_samples) @@ -347,7 +355,9 @@ def embed_in_audio_lsb( debug.print(f"Modified {modified_count} samples (out of {samples_needed} selected)") - # 7. Reshape and write back as WAV + # 7. Reshape and write back as PCM_16 WAV + # LSB steganography requires integer samples — writing as FLOAT/DOUBLE + # destroys LSBs due to float32 precision loss (33k/65k values fail round-trip). stego_samples = flat_samples.reshape(original_shape) output_buf = io.BytesIO() @@ -412,7 +422,7 @@ def extract_from_audio_lsb( ) try: - # 1. Read audio + # 1. Read audio as int16 directly (stego output is always PCM_16) samples, samplerate = sf.read(io.BytesIO(audio_data), dtype="int16", always_2d=True) flat_samples = samples.flatten() num_samples = len(flat_samples) diff --git a/src/stegasoo/constants.py b/src/stegasoo/constants.py index d5a8e8d..8d30235 100644 --- a/src/stegasoo/constants.py +++ b/src/stegasoo/constants.py @@ -383,6 +383,12 @@ ALLOWED_AUDIO_EXTENSIONS = {"wav", "flac", "mp3", "ogg", "opus", "aac", "m4a", " # Spread spectrum parameters AUDIO_SS_CHIP_LENGTH = 1024 # Samples per chip (spreading factor) — legacy/default AUDIO_SS_AMPLITUDE = 0.05 # Per-sample embedding strength (~-26dB, masked by audio) + +# Adaptive amplitude: embed at a fixed ratio below the carrier's RMS level. +# Keeps noise inaudible under content while ensuring extraction reliability. +AUDIO_SS_AMPLITUDE_RATIO = 0.25 # Fraction of carrier RMS (≈ -12 dB below signal) +AUDIO_SS_AMPLITUDE_MIN = 0.001 # Floor: ensures correlation ≥ 0.256 at chip=256 +AUDIO_SS_AMPLITUDE_MAX = 0.05 # Ceiling: never exceed original fixed amplitude AUDIO_SS_RS_NSYM = 32 # Reed-Solomon parity symbols # Spread spectrum v2: per-channel hybrid embedding (v4.4.0) @@ -406,6 +412,14 @@ AUDIO_SS_CHIP_TIER_NAMES = { AUDIO_LFE_CHANNEL_INDEX = 3 # Standard WAV/WAVEFORMATEXTENSIBLE ordering AUDIO_LFE_MIN_CHANNELS = 6 # Only skip LFE for 5.1+ layouts +# Compact padding for audio encryption (limited carrier capacity) +AUDIO_PAD_MIN = 8 # Minimum random padding bytes (vs 64 for images) +AUDIO_PAD_RANGE = 32 # Random padding range (vs 256 for images) +AUDIO_PAD_ALIGN = 32 # Alignment boundary (vs 256 for images) + +# Lossless audio formats (safe for low chip tier) +LOSSLESS_AUDIO_FORMATS = {"wav", "flac", "aiff"} + # Echo hiding parameters AUDIO_ECHO_DELAY_0 = 50 # Echo delay for bit 0 (samples at 44.1kHz ~ 1.1ms) AUDIO_ECHO_DELAY_1 = 100 # Echo delay for bit 1 (samples at 44.1kHz ~ 2.3ms) diff --git a/src/stegasoo/crypto.py b/src/stegasoo/crypto.py index 3536c2e..5c12a58 100644 --- a/src/stegasoo/crypto.py +++ b/src/stegasoo/crypto.py @@ -36,6 +36,9 @@ from .constants import ( ARGON2_MEMORY_COST, ARGON2_PARALLELISM, ARGON2_TIME_COST, + AUDIO_PAD_ALIGN, + AUDIO_PAD_MIN, + AUDIO_PAD_RANGE, FORMAT_VERSION, IV_SIZE, MAGIC_HEADER, @@ -433,6 +436,7 @@ def encrypt_message( pin: str = "", rsa_key_data: bytes | None = None, channel_key: str | bool | None = None, + compact: bool = False, ) -> bytes: """ Encrypt message or file using AES-256-GCM. @@ -492,8 +496,17 @@ def encrypt_message( ) # Random padding to hide message length - padding_len = secrets.randbelow(256) + 64 - padded_len = ((len(packed_payload) + padding_len + 255) // 256) * 256 + # Compact mode uses smaller padding for capacity-limited carriers (audio) + if compact: + pad_min = AUDIO_PAD_MIN + pad_range = AUDIO_PAD_RANGE + pad_align = AUDIO_PAD_ALIGN + else: + pad_min = 64 + pad_range = 256 + pad_align = 256 + padding_len = secrets.randbelow(pad_range) + pad_min + padded_len = ((len(packed_payload) + padding_len + pad_align - 1) // pad_align) * pad_align padding_needed = padded_len - len(packed_payload) padding = secrets.token_bytes(padding_needed - 4) + struct.pack(">I", len(packed_payload)) padded_message = packed_payload + padding diff --git a/src/stegasoo/encode.py b/src/stegasoo/encode.py index 604d2b0..eaa7d8b 100644 --- a/src/stegasoo/encode.py +++ b/src/stegasoo/encode.py @@ -337,9 +337,10 @@ def encode_audio( debug.print(f"Transcoding {audio_format} to WAV for embedding") carrier_audio = transcode_to_wav(carrier_audio) - # Encrypt message + # Encrypt message (compact padding for audio's limited capacity) encrypted = encrypt_message( - message, reference_photo, passphrase, pin, rsa_key_data, channel_key + message, reference_photo, passphrase, pin, rsa_key_data, channel_key, + compact=True, ) debug.print(f"Encrypted payload: {len(encrypted)} bytes") @@ -354,10 +355,21 @@ def encode_audio( encrypted, carrier_audio, pixel_key, progress_file=progress_file ) elif embed_mode == EMBED_MODE_AUDIO_SPREAD: - from .constants import AUDIO_SS_DEFAULT_CHIP_TIER + from .constants import ( + AUDIO_SS_CHIP_TIER_LOSSLESS, + AUDIO_SS_DEFAULT_CHIP_TIER, + LOSSLESS_AUDIO_FORMATS, + ) from .spread_steganography import embed_in_audio_spread - tier = chip_tier if chip_tier is not None else AUDIO_SS_DEFAULT_CHIP_TIER + if chip_tier is not None: + tier = chip_tier + elif audio_format in LOSSLESS_AUDIO_FORMATS: + tier = AUDIO_SS_CHIP_TIER_LOSSLESS + debug.print(f"Auto-selected chip tier 0 (lossless) for {audio_format} carrier") + else: + tier = AUDIO_SS_DEFAULT_CHIP_TIER + debug.print(f"Auto-selected chip tier {tier} (lossy) for {audio_format} carrier") stego_audio, stats = embed_in_audio_spread( encrypted, carrier_audio, pixel_key, chip_tier=tier, progress_file=progress_file ) diff --git a/src/stegasoo/spread_steganography.py b/src/stegasoo/spread_steganography.py index 5f8f08e..ff7dc34 100644 --- a/src/stegasoo/spread_steganography.py +++ b/src/stegasoo/spread_steganography.py @@ -50,6 +50,9 @@ from .constants import ( AUDIO_LFE_MIN_CHANNELS, AUDIO_MAGIC_SPREAD, AUDIO_SS_AMPLITUDE, + AUDIO_SS_AMPLITUDE_MAX, + AUDIO_SS_AMPLITUDE_MIN, + AUDIO_SS_AMPLITUDE_RATIO, AUDIO_SS_CHIP_LENGTH, AUDIO_SS_CHIP_LENGTHS, AUDIO_SS_DEFAULT_CHIP_TIER, @@ -81,6 +84,21 @@ except ImportError: ReedSolomonError = None # type: ignore[assignment,misc] +def _adaptive_amplitude(samples: np.ndarray) -> float: + """ + Compute embedding amplitude scaled to the carrier's signal level. + + Uses AUDIO_SS_AMPLITUDE_RATIO of the carrier RMS, clamped between + AUDIO_SS_AMPLITUDE_MIN and AUDIO_SS_AMPLITUDE_MAX. This keeps the + embedded noise inaudible for quiet recordings while preserving the + original fixed amplitude for loud carriers. + """ + rms = float(np.sqrt(np.mean(samples**2))) + adaptive = rms * AUDIO_SS_AMPLITUDE_RATIO + amplitude = max(AUDIO_SS_AMPLITUDE_MIN, min(AUDIO_SS_AMPLITUDE_MAX, adaptive)) + return amplitude + + # Header sizes _V0_HEADER_SIZE = 16 # Legacy: 4B magic + 3x4B length _V2_HEADER_SIZE = 20 # v2: 4B magic + 1B ver + 1B tier + 1B nch + 1B flags + 3x4B length @@ -687,9 +705,13 @@ def embed_in_audio_spread( lfe_skipped = len(embed_ch) < channels chip_length = AUDIO_SS_CHIP_LENGTHS.get(chip_tier, AUDIO_SS_CHIP_LENGTH) + # Compute adaptive amplitude scaled to carrier signal level + amplitude = _adaptive_amplitude(samples) + debug.print( f"Carrier: {sample_rate} Hz, {channels} ch ({num_embed_ch} embeddable), " - f"{num_frames} frames, {duration:.2f}s, chip={chip_length}" + f"{num_frames} frames, {duration:.2f}s, chip={chip_length}, " + f"amplitude={amplitude:.6f}" ) # 3. RS-encode the payload @@ -709,7 +731,7 @@ def embed_in_audio_spread( samples[:, embed_ch[0]], header_bits, seed, - AUDIO_SS_AMPLITUDE, + amplitude, _HEADER_CHIP_LENGTH, channel_index=0, offset=0, @@ -752,7 +774,7 @@ def embed_in_audio_spread( samples[:, ch], bits_for_ch, seed, - AUDIO_SS_AMPLITUDE, + amplitude, chip_length, channel_index=ch, offset=payload_offset, diff --git a/test_data/stupid_elitist_speech.wav b/test_data/stupid_elitist_speech.wav new file mode 100644 index 0000000..696407e Binary files /dev/null and b/test_data/stupid_elitist_speech.wav differ