Compare commits
2 Commits
pre-monore
...
c970261e53
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c970261e53 | ||
|
|
4607ff27dd |
5
.claude/settings.json
Normal file
5
.claude/settings.json
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"enabledPlugins": {
|
||||||
|
"church@church": true
|
||||||
|
}
|
||||||
|
}
|
||||||
30
agentstuff/pyproject.toml
Normal file
30
agentstuff/pyproject.toml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
[project]
|
||||||
|
name = "sentiment-agent"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "AI agent for gathering data and performing sentiment analysis"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"claude-agent-sdk",
|
||||||
|
"anyio",
|
||||||
|
"httpx",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest",
|
||||||
|
"ruff",
|
||||||
|
"mypy",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
sentiment-agent = "sentiment_agent.main:main"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
ignore_missing_imports = true
|
||||||
3
agentstuff/sentiment_agent/__init__.py
Normal file
3
agentstuff/sentiment_agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Sentiment analysis agent powered by Claude Agent SDK."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
115
agentstuff/sentiment_agent/agent.py
Normal file
115
agentstuff/sentiment_agent/agent.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""Core sentiment analysis agent using Claude Agent SDK."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from claude_agent_sdk import (
|
||||||
|
AssistantMessage,
|
||||||
|
ClaudeAgentOptions,
|
||||||
|
ClaudeSDKClient,
|
||||||
|
ResultMessage,
|
||||||
|
TextBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sentiment_agent.config import SafetyConfig
|
||||||
|
from sentiment_agent.tools import create_social_tools_server
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = """\
|
||||||
|
You are a sentiment analysis agent. Your job is to gather data from multiple \
|
||||||
|
platforms and produce a structured, evidence-based sentiment report.
|
||||||
|
|
||||||
|
## Rules — you MUST follow these
|
||||||
|
|
||||||
|
1. **Budget awareness.** You have a limited API call budget. Call \
|
||||||
|
`get_api_budget_status` before starting and after every few tool calls. \
|
||||||
|
Stop gathering data when you have <5 calls remaining and begin your analysis.
|
||||||
|
|
||||||
|
2. **Credibility first.** Every tool result includes credibility scores and \
|
||||||
|
bot/disinfo flags. You MUST:
|
||||||
|
- NEVER quote or cite posts marked `likely_inauthentic` (score < 0.3).
|
||||||
|
- Flag posts marked `suspicious` (score 0.3–0.5) with a warning when citing them.
|
||||||
|
- Give more weight to `likely_authentic` posts (score ≥ 0.7).
|
||||||
|
- If coordination warnings appear (copy-paste campaigns, burst posting), \
|
||||||
|
call them out prominently in your report.
|
||||||
|
|
||||||
|
3. **Platform diversity.** Gather from at least 2 different platforms before \
|
||||||
|
analyzing. Do not over-index on a single source.
|
||||||
|
|
||||||
|
4. **No fabrication.** Only report on data you actually retrieved. If a tool \
|
||||||
|
call fails or returns no results, say so — do not invent data.
|
||||||
|
|
||||||
|
5. **Structured output.** Your final report MUST include these sections:
|
||||||
|
- **Data Quality Summary**: platforms queried, posts analyzed vs excluded, \
|
||||||
|
coordination warnings
|
||||||
|
- **Overall Sentiment**: score (-1.0 to +1.0) and label \
|
||||||
|
(very negative / negative / mixed / neutral / positive / very positive)
|
||||||
|
- **Platform Breakdown**: sentiment per platform with sample size
|
||||||
|
- **Key Themes**: top 3-5 themes with sentiment direction
|
||||||
|
- **Credibility Concerns**: any bot networks, disinfo patterns, or \
|
||||||
|
coordinated campaigns detected
|
||||||
|
- **Notable Quotes**: 3-5 representative quotes (authentic sources only, \
|
||||||
|
with credibility score noted)
|
||||||
|
- **Confidence Assessment**: how confident you are in the analysis given \
|
||||||
|
data quality and volume
|
||||||
|
|
||||||
|
6. **Scope discipline.** Stay focused on the requested topic. Do not expand \
|
||||||
|
scope, follow tangents, or analyze adjacent topics unless explicitly asked.
|
||||||
|
|
||||||
|
7. **No side effects.** Do not write files, run commands, or take any action \
|
||||||
|
beyond reading data and producing your report.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def run_sentiment_analysis(
    topic: str,
    sources: list[str] | None = None,
    config: SafetyConfig | None = None,
) -> str:
    """Run the sentiment analysis agent on a given topic.

    Assistant text is streamed to stdout as it arrives.

    Args:
        topic: The topic or subject to analyze sentiment for.
        sources: Optional list of URLs or data sources to analyze.
        config: Safety configuration. Defaults to SafetyConfig.from_env().

    Returns:
        The agent's sentiment analysis report. If the final result message
        carries no text, the concatenation of the streamed assistant text is
        returned instead, so the return value is always a ``str``.
    """
    if config is None:
        config = SafetyConfig.from_env()

    source_instructions = ""
    if sources:
        source_list = "\n".join(f"- {s}" for s in sources)
        source_instructions = f"\n\nAlso analyze these specific sources:\n{source_list}"

    prompt = (
        f"Perform a sentiment analysis on the following topic: {topic}\n\n"
        "Start by calling `get_api_budget_status` to check your budget, then "
        "gather data from multiple platforms (Reddit, Hacker News, Bluesky if "
        "configured, and web search). Pay close attention to credibility scores "
        "and coordination warnings in the results."
        f"{source_instructions}"
    )

    social_server = create_social_tools_server(config)

    options = ClaudeAgentOptions(
        # Only allow read-only tools — no Write/Bash to prevent side effects
        # NOTE(review): the tools exposed by the "social" MCP server are not
        # listed here; confirm they are permitted without an explicit
        # `mcp__social__*` entry.
        allowed_tools=["WebSearch", "WebFetch", "Read"],
        max_turns=config.max_turns,
        max_budget_usd=config.max_budget_usd,
        mcp_servers={"social": social_server},
        system_prompt=SYSTEM_PROMPT,
    )

    result_text = ""
    streamed_chunks: list[str] = []
    async with ClaudeSDKClient(options=options) as client:
        await client.query(prompt)
        async for message in client.receive_response():
            if isinstance(message, AssistantMessage):
                for block in message.content:
                    if isinstance(block, TextBlock):
                        print(block.text, end="", flush=True)
                        streamed_chunks.append(block.text)
            if isinstance(message, ResultMessage):
                # `result` can be absent (e.g. the run ended without a final
                # answer); never let None leak out of a `-> str` function.
                result_text = message.result or ""

    return result_text or "".join(streamed_chunks)
|
||||||
1
agentstuff/sentiment_agent/clients/__init__.py
Normal file
1
agentstuff/sentiment_agent/clients/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""API clients for social media and forum data sources."""
|
||||||
166
agentstuff/sentiment_agent/clients/bluesky.py
Normal file
166
agentstuff/sentiment_agent/clients/bluesky.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Bluesky client using the AT Protocol API.
|
||||||
|
|
||||||
|
Search requires authentication. Set BLUESKY_HANDLE and BLUESKY_APP_PASSWORD
|
||||||
|
env vars. Create an app password at: https://bsky.app/settings/app-passwords
|
||||||
|
|
||||||
|
Thread fetching works without auth via the public API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
BSKY_PUBLIC_API = "https://public.api.bsky.app"
|
||||||
|
BSKY_AUTH_API = "https://bsky.social"
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_session() -> dict | None:
|
||||||
|
"""Authenticate with Bluesky and return session tokens, or None if no creds."""
|
||||||
|
handle = os.environ.get("BLUESKY_HANDLE")
|
||||||
|
app_password = os.environ.get("BLUESKY_APP_PASSWORD")
|
||||||
|
if not handle or not app_password:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=15) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{BSKY_AUTH_API}/xrpc/com.atproto.server.createSession",
|
||||||
|
json={"identifier": handle, "password": app_password},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
def _format_post(post_view: dict) -> dict:
    """Extract relevant fields from an AT Protocol post view.

    Accepts either a wrapper dict holding the post under a "post" key (as
    thread nodes do) or the post dict itself.

    Returns:
        A flat dict with text, author, timestamp, engagement counts, the
        at:// URI, the CID, and a bsky.app URL derived from the URI.
    """
    post = post_view.get("post", post_view)
    record = post.get("record", {})
    author = post.get("author", {})
    uri = post.get("uri", "")
    handle = author.get("handle", "")
    return {
        "text": record.get("text", ""),
        "author_handle": handle,
        "author_display_name": author.get("displayName", ""),
        "created_at": record.get("createdAt", ""),
        "like_count": post.get("likeCount", 0),
        "repost_count": post.get("repostCount", 0),
        "reply_count": post.get("replyCount", 0),
        "uri": uri,
        "cid": post.get("cid", ""),
        "url": _uri_to_url(uri, handle),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _uri_to_url(uri: str, handle: str) -> str:
    """Convert an at:// URI to a bsky.app URL.

    Example:
        at://did:plc:xxx/app.bsky.feed.post/rkey
        -> https://bsky.app/profile/handle/post/rkey

    Args:
        uri: The post's at:// URI.
        handle: The author's handle. When empty, the authority segment of the
            URI (the DID) is used in the profile path instead, so the result
            never contains an empty ``/profile//`` segment.

    Returns:
        The web URL, or "" if ``uri`` is not an at:// URI with a record key.
    """
    if not uri.startswith("at://"):
        return ""
    parts = uri.split("/")
    # ["at:", "", authority, collection, rkey] is the minimum usable shape.
    if len(parts) < 5:
        return ""
    profile = handle or parts[2]
    rkey = parts[-1]
    return f"https://bsky.app/profile/{profile}/post/{rkey}"
|
||||||
|
|
||||||
|
|
||||||
|
async def search_posts(query: str, limit: int = 25, sort: str = "top") -> list[dict]:
    """Search Bluesky for posts matching a query.

    Requires BLUESKY_HANDLE and BLUESKY_APP_PASSWORD env vars.

    Args:
        query: Search terms.
        limit: Max results (clamped to the range 1-100).
        sort: "top" (most liked) or "latest" (chronological).

    Returns:
        List of post dicts with: text, author_handle, author_display_name,
        created_at, like_count, repost_count, reply_count, uri, cid, url.

    Raises:
        RuntimeError: If Bluesky credentials are not configured.
        httpx.HTTPStatusError: If the search request fails.
    """
    session = await _get_session()
    if not session:
        raise RuntimeError(
            "Bluesky search requires authentication. "
            "Set BLUESKY_HANDLE and BLUESKY_APP_PASSWORD environment variables. "
            "Create an app password at: https://bsky.app/settings/app-passwords"
        )

    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{BSKY_AUTH_API}/xrpc/app.bsky.feed.searchPosts",
            params={
                "q": query,
                # Keep the value inside the accepted window on both ends.
                "limit": max(1, min(limit, 100)),
                "sort": sort,
            },
            headers={"Authorization": f"Bearer {session['accessJwt']}"},
        )
        resp.raise_for_status()
        data = resp.json()

    return [_format_post(p) for p in data.get("posts", [])]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_thread(uri: str, depth: int = 6) -> dict:
    """Fetch a Bluesky thread by AT URI or bsky.app URL.

    Args:
        uri: Either an at:// URI or a https://bsky.app/profile/.../post/... URL.
        depth: How many levels of replies to request (clamped to 0-1000).

    Returns:
        Dict with "post" (the root post) and "replies" (list of reply post
        dicts). Only direct replies and their immediate children are included
        in "replies", regardless of ``depth``.

    Raises:
        ValueError: If a bsky.app URL cannot be parsed.
        httpx.HTTPStatusError: If a request fails.
    """
    # Convert bsky.app URL to AT URI if needed
    if uri.startswith("https://bsky.app/"):
        uri = await _resolve_url_to_uri(uri)

    # Authentication is optional here; attach a token only when one exists.
    headers = {}
    session = await _get_session()
    if session:
        headers["Authorization"] = f"Bearer {session['accessJwt']}"

    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{BSKY_PUBLIC_API}/xrpc/app.bsky.feed.getPostThread",
            params={"uri": uri, "depth": max(0, min(depth, 1000))},
            headers=headers,
        )
        resp.raise_for_status()
        data = resp.json()

    thread = data.get("thread", {})
    root_post = _format_post(thread) if "post" in thread else {}

    replies = []
    for reply in thread.get("replies", []):
        # Nodes without a "post" key are skipped, but their children are
        # still inspected.
        if "post" in reply:
            replies.append(_format_post(reply))
        # Include nested replies one level deep
        for nested in reply.get("replies", []):
            if "post" in nested:
                replies.append(_format_post(nested))

    return {"post": root_post, "replies": replies}
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_url_to_uri(url: str) -> str:
    """Convert a bsky.app URL to an AT URI by resolving the handle.

    Expected shape: https://bsky.app/profile/<handle>/post/<rkey>

    Raises:
        ValueError: If the URL does not have the expected shape.
        httpx.HTTPStatusError: If handle resolution fails.
    """
    parts = url.rstrip("/").split("/")
    # ["https:", "", "bsky.app", "profile", handle, "post", rkey] -> 7 parts.
    # The rkey lives at index 6, so anything shorter cannot be a post URL.
    if len(parts) < 7 or parts[3] != "profile" or parts[5] != "post":
        raise ValueError(f"Invalid Bluesky URL: {url}")

    handle = parts[4]  # profile/{handle}
    rkey = parts[6]  # post/{rkey}
    if not handle or not rkey:
        raise ValueError(f"Invalid Bluesky URL: {url}")

    # Resolve handle to DID
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.get(
            f"{BSKY_PUBLIC_API}/xrpc/com.atproto.identity.resolveHandle",
            params={"handle": handle},
        )
        resp.raise_for_status()
        did = resp.json()["did"]

    return f"at://{did}/app.bsky.feed.post/{rkey}"
|
||||||
78
agentstuff/sentiment_agent/clients/hackernews.py
Normal file
78
agentstuff/sentiment_agent/clients/hackernews.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""Hacker News client using the Algolia HN Search API.
|
||||||
|
|
||||||
|
No authentication required. Docs: https://hn.algolia.com/api
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
HN_API_BASE = "https://hn.algolia.com/api/v1"
|
||||||
|
|
||||||
|
|
||||||
|
async def search_stories(query: str, limit: int = 25) -> list[dict]:
    """Search HN for stories matching a query.

    Args:
        query: Search terms.
        limit: Max results (clamped to the range 1-50).

    Returns:
        A list of story dicts with: title, url, author, points,
        num_comments, created_at, object_id, story_text, hn_url. Fields that
        come back as JSON null are normalised to "" (text) or 0 (counts).

    Raises:
        httpx.HTTPStatusError: If the search request fails.
    """
    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{HN_API_BASE}/search",
            params={
                "query": query,
                "tags": "story",
                "hitsPerPage": max(1, min(limit, 50)),
            },
        )
        resp.raise_for_status()
        data = resp.json()

    results = []
    for hit in data.get("hits", []):
        object_id = hit.get("objectID") or ""
        results.append(
            {
                # `or` rather than a .get default: a key that is present with
                # a null value would otherwise leak None to callers.
                "title": hit.get("title") or "",
                "url": hit.get("url") or "",
                "author": hit.get("author") or "",
                "points": hit.get("points") or 0,
                "num_comments": hit.get("num_comments") or 0,
                "created_at": hit.get("created_at") or "",
                "object_id": object_id,
                "story_text": hit.get("story_text") or "",
                "hn_url": f"https://news.ycombinator.com/item?id={object_id}",
            }
        )
    return results
|
||||||
|
|
||||||
|
|
||||||
|
async def search_comments(query: str, limit: int = 25) -> list[dict]:
    """Search HN for comments matching a query.

    Args:
        query: Search terms.
        limit: Max results (clamped to the range 1-50).

    Returns:
        A list of comment dicts with: comment_text, author, points,
        created_at, story_title, story_url, hn_url. Fields that come back as
        JSON null are normalised to "" (text) or 0 (counts).

    Raises:
        httpx.HTTPStatusError: If the search request fails.
    """
    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{HN_API_BASE}/search",
            params={
                "query": query,
                "tags": "comment",
                "hitsPerPage": max(1, min(limit, 50)),
            },
        )
        resp.raise_for_status()
        data = resp.json()

    results = []
    for hit in data.get("hits", []):
        object_id = hit.get("objectID") or ""
        results.append(
            {
                # `or` rather than a .get default: a key that is present with
                # a null value would otherwise leak None to callers.
                "comment_text": hit.get("comment_text") or "",
                "author": hit.get("author") or "",
                "points": hit.get("points") or 0,
                "created_at": hit.get("created_at") or "",
                "story_title": hit.get("story_title") or "",
                "story_url": hit.get("story_url") or "",
                "hn_url": f"https://news.ycombinator.com/item?id={object_id}",
            }
        )
    return results
|
||||||
117
agentstuff/sentiment_agent/clients/reddit.py
Normal file
117
agentstuff/sentiment_agent/clients/reddit.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""Reddit client using the public JSON API.
|
||||||
|
|
||||||
|
No authentication required for read-only search. Reddit requires a descriptive
|
||||||
|
User-Agent header — requests with generic UAs get 429'd.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
REDDIT_BASE = "https://www.reddit.com"
|
||||||
|
USER_AGENT = "sentiment-agent/0.1.0 (research; sentiment analysis tool)"
|
||||||
|
|
||||||
|
|
||||||
|
async def search_posts(
    query: str,
    subreddit: str = "all",
    sort: str = "relevance",
    time_filter: str = "month",
    limit: int = 25,
) -> list[dict]:
    """Search Reddit for posts matching a query.

    Args:
        query: Search terms.
        subreddit: Subreddit to search within, or "all" for site-wide.
        sort: One of "relevance", "hot", "top", "new", "comments".
        time_filter: One of "hour", "day", "week", "month", "year", "all".
        limit: Max results (clamped to the range 1-100).

    Returns:
        List of post dicts with: title, selftext, author, score,
        upvote_ratio, num_comments, subreddit, url, permalink, created_utc,
        is_self. ``selftext`` is truncated to 2000 characters.

    Raises:
        httpx.HTTPStatusError: If the search request fails.
    """
    url = f"{REDDIT_BASE}/r/{subreddit}/search.json"
    async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
        resp = await client.get(
            url,
            params={
                "q": query,
                "sort": sort,
                "t": time_filter,
                "limit": max(1, min(limit, 100)),
                # Only restrict to the subreddit when one was actually named.
                "restrict_sr": "on" if subreddit != "all" else "off",
            },
            headers={"User-Agent": USER_AGENT},
        )
        resp.raise_for_status()
        data = resp.json()

    results = []
    for child in data.get("data", {}).get("children", []):
        post = child.get("data", {})
        results.append(
            {
                "title": post.get("title") or "",
                "selftext": (post.get("selftext") or "")[:2000],
                "author": post.get("author") or "[deleted]",
                "score": post.get("score") or 0,
                "upvote_ratio": post.get("upvote_ratio") or 0,
                "num_comments": post.get("num_comments") or 0,
                "subreddit": post.get("subreddit") or "",
                "url": post.get("url") or "",
                "permalink": f"https://reddit.com{post.get('permalink') or ''}",
                "created_utc": post.get("created_utc") or 0,
                "is_self": bool(post.get("is_self", False)),
            }
        )
    return results
|
||||||
|
|
||||||
|
|
||||||
|
async def get_post_comments(
    permalink: str,
    sort: str = "top",
    limit: int = 25,
) -> list[dict]:
    """Fetch top-level comments for a Reddit post.

    Args:
        permalink: The post's permalink path (e.g., "/r/python/comments/abc123/title/").
            A full URL is also accepted; only its path component is used.
        sort: Comment sort order: "top", "best", "new", "controversial".
        limit: Max comments to return.

    Returns:
        List of comment dicts with: body, author, score, created_utc.
        ``body`` is truncated to 2000 characters. An empty list is returned
        when the response does not have the expected two-listing shape.

    Raises:
        httpx.HTTPStatusError: If the request fails.
    """
    from urllib.parse import urlsplit

    # Reduce a full URL to its path. Parsing (rather than stripping two fixed
    # prefixes) also handles http://, other reddit hosts, and drops any query
    # string or fragment that would otherwise end up in front of ".json".
    if permalink.startswith(("http://", "https://")):
        permalink = urlsplit(permalink).path
    if not permalink.startswith("/"):
        permalink = f"/{permalink}"

    url = f"{REDDIT_BASE}{permalink}.json"
    async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
        resp = await client.get(
            url,
            params={"sort": sort, "limit": limit},
            headers={"User-Agent": USER_AGENT},
        )
        resp.raise_for_status()
        data = resp.json()

    # Reddit returns [post_listing, comments_listing]
    if not isinstance(data, list) or len(data) < 2:
        return []

    results = []
    for child in data[1].get("data", {}).get("children", []):
        # Only kind "t1" children are comments; anything else is skipped.
        if child.get("kind") != "t1":
            continue
        comment = child.get("data", {})
        results.append(
            {
                "body": (comment.get("body") or "")[:2000],
                "author": comment.get("author") or "[deleted]",
                "score": comment.get("score") or 0,
                "created_utc": comment.get("created_utc") or 0,
            }
        )
    return results
|
||||||
70
agentstuff/sentiment_agent/config.py
Normal file
70
agentstuff/sentiment_agent/config.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""Configuration and safety limits for the sentiment agent.
|
||||||
|
|
||||||
|
All guardrails are centralized here so they can be tuned from one place
|
||||||
|
or overridden via CLI flags / env vars.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class RateLimitConfig:
    """Per-platform rate limiting."""

    requests_per_minute: int = 10
    burst_size: int = 3  # max concurrent requests
    cooldown_after_429: float = 30.0  # seconds to wait after a 429
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SafetyConfig:
    """Top-level safety rails for the agent."""

    # --- Agent-level limits ---
    max_turns: int = 20
    max_budget_usd: float = 0.50  # hard cap on Claude API spend per run
    max_total_api_calls: int = 50  # across ALL platforms combined
    max_results_per_call: int = 50  # cap the `limit` param sent to any API

    # --- Per-platform rate limits ---
    bluesky_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig(
        requests_per_minute=10, burst_size=2,
    ))
    reddit_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig(
        requests_per_minute=10, burst_size=2,
    ))
    hackernews_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig(
        requests_per_minute=15, burst_size=3,  # HN Algolia is more generous
    ))

    # --- Content size limits ---
    max_post_text_chars: int = 2000  # truncate individual posts beyond this
    max_total_content_bytes: int = 500_000  # ~500KB total data gathered before agent stops

    # --- Timeout ---
    api_timeout_seconds: float = 15.0

    # --- Credibility thresholds ---
    min_credibility_score: float = 0.3  # posts below this are flagged/excluded
    flag_bot_threshold: float = 0.5  # posts between min and this are flagged but included

    @classmethod
    def from_env(cls) -> SafetyConfig:
        """Build config with env var overrides.

        Env vars: SENTIMENT_MAX_TURNS, SENTIMENT_MAX_BUDGET_USD,
        SENTIMENT_MAX_API_CALLS, SENTIMENT_MIN_CREDIBILITY.

        Unset or empty variables leave the corresponding default in place.

        Raises:
            ValueError: If a variable is set but cannot be parsed; the message
                names the offending variable.
        """
        # (env var, field name, parser)
        overrides = (
            ("SENTIMENT_MAX_TURNS", "max_turns", int),
            ("SENTIMENT_MAX_BUDGET_USD", "max_budget_usd", float),
            ("SENTIMENT_MAX_API_CALLS", "max_total_api_calls", int),
            ("SENTIMENT_MIN_CREDIBILITY", "min_credibility_score", float),
        )
        kwargs: dict = {}
        for env_name, field_name, parse in overrides:
            if raw := os.environ.get(env_name):
                try:
                    kwargs[field_name] = parse(raw)
                except ValueError as err:
                    raise ValueError(
                        f"Invalid value for {env_name}: {raw!r}"
                    ) from err
        return cls(**kwargs)
|
||||||
398
agentstuff/sentiment_agent/credibility.py
Normal file
398
agentstuff/sentiment_agent/credibility.py
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
"""Credibility scoring and bot/disinfo detection.
|
||||||
|
|
||||||
|
Assigns a 0.0–1.0 credibility score to each post based on heuristic signals.
|
||||||
|
Posts below the configured threshold are excluded or flagged so they don't
|
||||||
|
pollute the sentiment analysis.
|
||||||
|
|
||||||
|
Signals are platform-aware — each platform has different indicators of
|
||||||
|
inauthentic behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CredibilityResult:
    """Credibility assessment for a single post."""

    score: float  # 0.0 (likely bot/disinfo) to 1.0 (likely authentic)
    flags: list[str] = field(default_factory=list)  # human-readable reasons
    is_excluded: bool = False  # below min_credibility_score
    is_flagged: bool = False  # between min and flag threshold

    @property
    def label(self) -> str:
        """Map the numeric score to a coarse label.

        >= 0.7 "likely_authentic", >= 0.5 "uncertain", >= 0.3 "suspicious",
        otherwise "likely_inauthentic".
        """
        if self.score >= 0.7:
            return "likely_authentic"
        if self.score >= 0.5:
            return "uncertain"
        if self.score >= 0.3:
            return "suspicious"
        return "likely_inauthentic"
|
||||||
|
|
||||||
|
|
||||||
|
# --- Shared heuristics ---
|
||||||
|
|
||||||
|
# Common bot patterns in text
|
||||||
|
_BOT_TEXT_PATTERNS = [
    # Crypto/scam spam
    re.compile(r"(?i)(dm me|check my bio|link in bio|click here|free giveaway)"),
    re.compile(r"(?i)(join my|subscribe to|follow me for|🔥.*🔥.*🔥)"),
    # Astroturfing phrases
    re.compile(r"(?i)(i (just )?(discovered|found|tried) this (amazing|incredible|awesome))"),
    re.compile(r"(?i)(game.?changer|life.?changing|you won'?t believe)"),
    # Excessive hashtags (5+)
    re.compile(r"(#\w+\s*){5,}"),
    # Walls of emojis (10+ consecutive)
    re.compile(r"[\U0001F300-\U0001FAFF]{10,}"),
    # Repetitive characters (spammy emphasis)
    re.compile(r"(.)\1{9,}"),
]

# Coordinated campaign indicators: identical or near-identical text
# This is checked at the batch level, not per-post


def _check_text_patterns(text: str) -> list[str]:
    """Check text against common bot/spam patterns.

    Args:
        text: The post or comment text. ``None`` is treated as empty.

    Returns:
        One "bot_text_pattern: ..." flag per matching pattern (the pattern
        source is truncated to 60 characters), plus "very_short_text" when
        the text is shorter than 15 characters.
    """
    text = text or ""
    flags = [
        f"bot_text_pattern: {pattern.pattern[:60]}"
        for pattern in _BOT_TEXT_PATTERNS
        if pattern.search(text)
    ]
    if len(text) < 15:
        flags.append("very_short_text")
    return flags
|
||||||
|
|
||||||
|
|
||||||
|
def _engagement_ratio_score(
    likes: int, reposts: int, replies: int
) -> tuple[float, list[str]]:
    """Score based on engagement ratios.

    Authentic posts tend to have a mix of likes, replies, and reposts.
    Bot-amplified posts often have inflated likes with very few replies,
    or massive repost counts with no discussion.

    Returns:
        A ``(score, flags)`` pair:
        - no engagement at all: 0.5 with "no_engagement"
        - more than 10 reposts and no replies: 0.3
        - more than 100 likes and no replies: 0.4
        - otherwise: 0.5 plus half the share of replies in total engagement
    """
    flags = []
    total = likes + reposts + replies

    if total == 0:
        return 0.5, ["no_engagement"]

    # High repost-to-reply ratio suggests amplification without discussion
    if replies == 0 and reposts > 10:
        flags.append(f"high_repost_no_replies: {reposts} reposts, 0 replies")
        return 0.3, flags

    # Extremely high like count with zero replies is suspicious
    if likes > 100 and replies == 0:
        flags.append(f"high_likes_no_replies: {likes} likes, 0 replies")
        return 0.4, flags

    # Normal engagement
    return min(1.0, 0.5 + (replies / max(total, 1)) * 0.5), flags
|
||||||
|
|
||||||
|
|
||||||
|
# --- Platform-specific scoring ---
|
||||||
|
|
||||||
|
|
||||||
|
def score_bluesky_post(post: dict) -> CredibilityResult:
    """Score a Bluesky post for credibility.

    Starts from 1.0 and subtracts penalties for spammy text, a hex-looking
    handle, and a missing display name; the result is then blended 60/40 with
    the engagement-ratio score and clamped to [0.0, 1.0].

    Args:
        post: A post dict with text, author_handle, author_display_name,
            like_count, repost_count and reply_count keys. Missing or null
            values are treated as empty / zero.

    Returns:
        A CredibilityResult carrying the score and the reasons collected.
        ``is_excluded`` / ``is_flagged`` are left at their defaults here.
    """
    score = 1.0
    flags: list[str] = []

    text = post.get("text") or ""
    handle = post.get("author_handle") or ""
    display_name = post.get("author_display_name") or ""
    likes = post.get("like_count") or 0
    reposts = post.get("repost_count") or 0
    replies = post.get("reply_count") or 0

    # Text pattern checks
    text_flags = _check_text_patterns(text)
    if text_flags:
        score -= 0.15 * len(text_flags)
        flags.extend(text_flags)

    # Handle heuristics
    # Randomly generated handles (long hex/number strings)
    if re.match(r"^[a-f0-9]{8,}\.", handle):
        flags.append(f"random_handle: {handle}")
        score -= 0.3

    # No display name set
    if not display_name or display_name == handle:
        flags.append("no_display_name")
        score -= 0.1

    # Engagement ratio
    eng_score, eng_flags = _engagement_ratio_score(likes, reposts, replies)
    flags.extend(eng_flags)
    score = score * 0.6 + eng_score * 0.4

    return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags)
|
||||||
|
|
||||||
|
|
||||||
|
def score_reddit_post(post: dict) -> CredibilityResult:
    """Score a Reddit post for credibility.

    Starts from 1.0 and subtracts a penalty for each signal that fires; the
    result is clamped to [0.0, 1.0].

    Args:
        post: A post dict with title, selftext, author, upvote_ratio, score,
            num_comments and is_self keys. Missing or null text and count
            values are treated as empty / zero.

    Returns:
        A CredibilityResult carrying the score and the reasons collected.
        ``is_excluded`` / ``is_flagged`` are left at their defaults here.
    """
    score = 1.0
    flags: list[str] = []

    title = post.get("title") or ""
    selftext = post.get("selftext") or ""
    # The body is scored when present; link posts fall back to the title.
    text = selftext or title
    author = post.get("author") or ""
    upvote_ratio = post.get("upvote_ratio", 0.5)
    if upvote_ratio is None:
        upvote_ratio = 0.5
    post_score = post.get("score") or 0
    num_comments = post.get("num_comments") or 0

    # Text patterns
    text_flags = _check_text_patterns(text)
    if text_flags:
        score -= 0.15 * len(text_flags)
        flags.extend(text_flags)

    # Deleted author
    if author in ("[deleted]", "[removed]"):
        flags.append("deleted_author")
        score -= 0.2

    # Suspicious username patterns (random alphanumeric + numbers)
    if re.match(r"^[A-Za-z]+[-_]?\d{4,}$", author):
        flags.append(f"auto_generated_username: {author}")
        score -= 0.15

    # Very controversial ratio (lots of up AND down votes)
    if upvote_ratio < 0.4 and post_score > 0:
        flags.append(f"highly_controversial: {upvote_ratio:.0%} upvote ratio")
        score -= 0.1

    # High score but zero comments = potential vote manipulation
    if post_score > 100 and num_comments == 0:
        flags.append(f"high_score_no_comments: {post_score} score, 0 comments")
        score -= 0.2

    # Low-effort cross-post spam: very short title, external link, no selftext
    if len(title) < 20 and not post.get("is_self", True) and not selftext:
        flags.append("possible_link_spam")
        score -= 0.1

    return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags)
|
||||||
|
|
||||||
|
|
||||||
|
def score_reddit_comment(comment: dict) -> CredibilityResult:
    """Score a Reddit comment for credibility.

    Starts from a score of 1.0 and subtracts a penalty per suspicious signal;
    each signal also contributes a flag.

    Args:
        comment: Reddit comment dict. Keys read: ``body``, ``author`` and
            ``score``. Missing keys fall back to neutral defaults.

    Returns:
        A CredibilityResult whose score is clamped to [0.0, 1.0].
    """
    score = 1.0
    flags: list[str] = []

    # A key that is present but None would make re.match() raise TypeError,
    # so coerce body and author to empty strings.
    body = comment.get("body") or ""
    author = comment.get("author") or ""
    comment_score = comment.get("score", 0)

    text_flags = _check_text_patterns(body)
    if text_flags:
        score -= 0.15 * len(text_flags)
        flags.extend(text_flags)

    if author in ("[deleted]", "[removed]"):
        flags.append("deleted_author")
        score -= 0.2

    # Letters, optional separator, then 4+ digits
    if re.match(r"^[A-Za-z]+[-_]?\d{4,}$", author):
        flags.append(f"auto_generated_username: {author}")
        score -= 0.15

    # Heavily downvoted
    if comment_score < -5:
        flags.append(f"heavily_downvoted: {comment_score}")
        score -= 0.15

    return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags)
|
||||||
|
|
||||||
|
|
||||||
|
def score_hackernews_post(post: dict) -> CredibilityResult:
    """Score a HN story for credibility.

    HN is generally higher-signal than social media, but we still check
    for low-effort submissions and spammy patterns.

    Args:
        post: Story dict. Keys read: ``title``, ``story_text`` and ``points``.

    Returns:
        A CredibilityResult whose score is clamped to [0.0, 1.0].
    """
    score = 1.0
    flags: list[str] = []

    # Guard against a present-but-None title so the text check gets a string.
    title = post.get("title") or ""
    text = post.get("story_text") or title
    points = post.get("points", 0)

    text_flags = _check_text_patterns(text)
    if text_flags:
        score -= 0.1 * len(text_flags)
        flags.extend(text_flags)

    # Zero points = the community didn't find it valuable
    if points == 0:
        flags.append("zero_points")
        score -= 0.1

    # HN is generally more credible: a flat +0.1 is applied after the
    # penalties, so it offsets one penalty step but can never lift the
    # result above 1.0 because of the final clamp.
    return CredibilityResult(score=max(0.0, min(1.0, score + 0.1)), flags=flags)
|
||||||
|
|
||||||
|
|
||||||
|
def score_hackernews_comment(comment: dict) -> CredibilityResult:
    """Rate the credibility of a single Hacker News comment.

    Only the ``comment_text`` field is inspected. Each text-pattern hit costs
    0.1, then a flat 0.1 bonus is added before the score is clamped to
    [0.0, 1.0].
    """
    pattern_hits = _check_text_patterns(comment.get("comment_text", ""))

    flags: list[str] = []
    credibility = 1.0
    if pattern_hits:
        credibility -= 0.1 * len(pattern_hits)
        flags.extend(pattern_hits)

    # HN comments are generally higher quality, hence the bonus.
    credibility = min(1.0, credibility + 0.1)

    bounded = max(0.0, min(1.0, credibility))
    return CredibilityResult(score=bounded, flags=flags)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Batch-level coordination detection ---
|
||||||
|
|
||||||
|
|
||||||
|
def detect_coordination(posts: list[dict], text_key: str = "text") -> list[str]:
    """Detect coordinated inauthentic behavior across a batch of posts.

    Looks for:
    - Exact duplicate text after trimming and lower-casing (copy-paste
      campaigns)
    - The same opening 50-80 characters on at least 30% of the texts
      (template campaigns); only checked when there are 5 or more texts
    - Burst posting: at least half of the timestamped posts inside one
      5-minute window; only checked when 5 or more timestamps parse

    Args:
        posts: Post dicts from any platform. Timestamps are read from
            ``created_at`` or ``created_utc`` and may be ISO-8601 strings or
            epoch seconds. ISO strings without an offset are treated as UTC.
        text_key: Key under which each post stores its text.

    Returns:
        A list of warning strings. Empty when nothing suspicious is found or
        when no post has text under ``text_key``.
    """
    # Imported locally so the module's import block does not have to change.
    from collections import Counter
    from datetime import timedelta

    warnings: list[str] = []
    texts = [p.get(text_key, "") for p in posts if p.get(text_key)]

    if not texts:
        return warnings

    normalized = [t.strip().lower() for t in texts]

    # Exact duplicates
    duplicates = {text: count for text, count in Counter(normalized).items() if count > 1}
    if duplicates:
        total_dupes = sum(duplicates.values())
        warnings.append(
            f"COORDINATION WARNING: {len(duplicates)} duplicate texts found "
            f"({total_dupes} total copies). Possible copy-paste campaign."
        )

    # Near-duplicates (simplified): do >=30% of the texts share the same
    # opening? Only texts of at least 50 characters take part.
    if len(texts) >= 5:
        prefixes = Counter(t[:80] for t in normalized if len(t) >= 50)
        share_threshold = len(texts) * 0.3
        for prefix, count in prefixes.items():
            if count >= share_threshold:
                warnings.append(
                    f"COORDINATION WARNING: {count}/{len(texts)} posts share "
                    f"a common prefix ({prefix[:50]}...). Possible template campaign."
                )

    # Burst detection: if timestamps are available
    timestamps = []
    for p in posts:
        created = p.get("created_at") or p.get("created_utc")
        if isinstance(created, str):
            try:
                parsed = datetime.fromisoformat(created.replace("Z", "+00:00"))
            except ValueError:
                continue  # unparseable timestamp: ignore this post
        elif isinstance(created, (int, float)):
            try:
                parsed = datetime.fromtimestamp(created, tz=timezone.utc)
            except (OverflowError, OSError, ValueError):
                continue  # epoch value outside the platform's range
        else:
            continue
        # Naive and aware datetimes cannot be compared; sorting a mix of
        # them raises TypeError, so treat offset-less values as UTC.
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        timestamps.append(parsed)

    if len(timestamps) >= 5:
        timestamps.sort()
        # Warn if at least half of the posts landed within a 5-minute window.
        window = timedelta(seconds=300)
        burst_threshold = len(timestamps) * 0.5
        end = 0
        # Two-pointer sweep: `end` only moves forward, so the scan is linear
        # after the sort.
        for start, first in enumerate(timestamps):
            window_end = first + window
            while end < len(timestamps) and timestamps[end] <= window_end:
                end += 1
            in_window = end - start
            if in_window >= burst_threshold:
                warnings.append(
                    f"COORDINATION WARNING: {in_window}/{len(timestamps)} posts "
                    "appeared within a 5-minute window. Possible coordinated posting."
                )
                break

    return warnings
|
||||||
|
|
||||||
|
|
||||||
|
def filter_and_annotate(
    posts: list[dict],
    scorer,
    min_score: float = 0.3,
    flag_threshold: float = 0.5,
) -> tuple[list[dict], dict]:
    """Run ``scorer`` over every post, drop the untrustworthy ones, tag the rest.

    Args:
        posts: Post dicts from any platform. Surviving posts are mutated in
            place: a ``"_credibility"`` entry is added to each.
        scorer: Callable taking one post and returning an object with
            ``score``, ``flags`` and ``label`` attributes (for example
            score_reddit_post). ``is_excluded`` and ``is_flagged`` are set on
            the returned object.
        min_score: Posts scoring below this are excluded.
        flag_threshold: Posts scoring at least ``min_score`` but below this
            are kept and marked as flagged.

    Returns:
        ``(kept_posts, stats)`` where ``stats`` has the keys ``total``,
        ``excluded``, ``flagged``, ``authentic`` and ``excluded_reasons``.
    """
    total = len(posts)
    kept: list[dict] = []
    dropped_reasons: list[dict] = []
    flagged_count = 0
    authentic_count = 0

    for post in posts:
        verdict = scorer(post)
        verdict.is_excluded = verdict.score < min_score
        verdict.is_flagged = min_score <= verdict.score < flag_threshold

        if verdict.is_excluded:
            # Keep a short audit trail of why the post was thrown away.
            dropped_reasons.append(
                {"score": round(verdict.score, 2), "flags": verdict.flags}
            )
            continue

        post["_credibility"] = {
            "score": round(verdict.score, 2),
            "label": verdict.label,
            "flags": verdict.flags,
            "is_flagged": verdict.is_flagged,
        }

        if verdict.is_flagged:
            flagged_count += 1
        else:
            authentic_count += 1
        kept.append(post)

    # Key order is kept stable because callers serialize this dict.
    summary = {
        "total": total,
        "excluded": len(dropped_reasons),
        "flagged": flagged_count,
        "authentic": authentic_count,
        "excluded_reasons": dropped_reasons,
    }
    return kept, summary
|
||||||
66
agentstuff/sentiment_agent/main.py
Normal file
66
agentstuff/sentiment_agent/main.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""CLI entry point for the sentiment analysis agent."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from sentiment_agent.agent import run_sentiment_analysis
|
||||||
|
from sentiment_agent.config import SafetyConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def async_main(args: argparse.Namespace) -> None:
    """Run one analysis from parsed CLI arguments and print the report.

    Builds a SafetyConfig from the limit and threshold options, awaits
    run_sentiment_analysis for the requested topic and sources, then prints
    the result under a banner.
    """
    safety_limits = SafetyConfig(
        max_turns=args.max_turns,
        max_budget_usd=args.max_budget,
        max_total_api_calls=args.max_api_calls,
        min_credibility_score=args.min_credibility,
        flag_bot_threshold=args.flag_threshold,
    )

    report = await run_sentiment_analysis(
        topic=args.topic,
        sources=args.sources,
        config=safety_limits,
    )

    rule = "=" * 60
    print("\n" + rule)
    print("SENTIMENT ANALYSIS REPORT")
    print(rule)
    print(report)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Parse command-line options and hand them to the async entry point."""
    parser = argparse.ArgumentParser(
        description="Run sentiment analysis on a topic with bot/disinfo detection",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("topic", help="The topic to analyze sentiment for")
    parser.add_argument(
        "--sources", nargs="*", help="Specific URLs or sources to also analyze"
    )

    # Each option group is declared as a table of (flag, type, default, help).
    safety_options = (
        ("--max-turns", int, 20, "Max agent turns"),
        ("--max-budget", float, 0.50, "Max Claude API spend (USD)"),
        ("--max-api-calls", int, 50, "Max total API calls across all platforms"),
    )
    safety_group = parser.add_argument_group("safety limits")
    for flag, value_type, default, help_text in safety_options:
        safety_group.add_argument(flag, type=value_type, default=default, help=help_text)

    credibility_options = (
        ("--min-credibility", float, 0.3, "Posts below this score are excluded (0.0-1.0)"),
        (
            "--flag-threshold",
            float,
            0.5,
            "Posts between min and this are flagged but included (0.0-1.0)",
        ),
    )
    credibility_group = parser.add_argument_group("credibility filtering")
    for flag, value_type, default, help_text in credibility_options:
        credibility_group.add_argument(flag, type=value_type, default=default, help=help_text)

    parsed = parser.parse_args()
    anyio.run(async_main, parsed)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
||||||
169
agentstuff/sentiment_agent/ratelimit.py
Normal file
169
agentstuff/sentiment_agent/ratelimit.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Rate limiter and API call budget tracker.
|
||||||
|
|
||||||
|
Enforces per-platform rate limits and a global call budget so the agent
|
||||||
|
can't hammer APIs or run up unbounded costs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from sentiment_agent.config import RateLimitConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetExhaustedError(Exception):
    """Signals that no calls remain in the global API call budget."""
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitExceededError(Exception):
    """Signals that a platform limit blocks the call right now.

    Used for an unexpired cooldown after a 429, for too many concurrent
    requests, and for too many requests in the last minute.
    """
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _PlatformState:
|
||||||
|
"""Tracks call timestamps and active request count for one platform."""
|
||||||
|
|
||||||
|
config: RateLimitConfig
|
||||||
|
call_timestamps: list[float] = field(default_factory=list)
|
||||||
|
active_requests: int = 0
|
||||||
|
last_429_at: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
    """Gatekeeper for outbound API calls: per-platform limits plus one shared budget.

    Usage:
        limiter = RateLimiter(max_total_calls=50)
        limiter.register_platform("reddit", RateLimitConfig(...))

        async with limiter.acquire("reddit"):
            await do_reddit_call()
    """

    def __init__(self, max_total_calls: int = 50):
        self._max_total = max_total_calls
        self._total_calls = 0
        self._platforms: dict[str, _PlatformState] = {}
        # Serializes admission and release so the counters stay consistent.
        self._lock = asyncio.Lock()

    @property
    def total_calls(self) -> int:
        """Number of calls admitted so far, across all platforms."""
        return self._total_calls

    @property
    def remaining_calls(self) -> int:
        """Calls left in the global budget; never negative."""
        return max(0, self._max_total - self._total_calls)

    def register_platform(self, name: str, config: RateLimitConfig) -> None:
        """Start tracking ``name`` with fresh state, replacing any earlier entry."""
        self._platforms[name] = _PlatformState(config=config)

    def acquire(self, platform: str) -> _AcquireContext:
        """Return an async context manager that admits one call to ``platform``."""
        return _AcquireContext(self, platform)

    async def _acquire(self, platform: str) -> None:
        """Admit one call or raise.

        Checks run in this order: global budget, platform registration,
        429 cooldown, concurrent-request limit, requests-per-minute limit.

        Raises:
            BudgetExhaustedError: the global budget is used up.
            ValueError: ``platform`` was never registered.
            RateLimitExceededError: a per-platform limit blocks the call.
        """
        async with self._lock:
            if self._total_calls >= self._max_total:
                raise BudgetExhaustedError(
                    f"Global API call budget exhausted ({self._max_total} calls). "
                    "Increase max_total_api_calls in SafetyConfig to allow more."
                )

            state = self._platforms.get(platform)
            if not state:
                raise ValueError(f"Platform '{platform}' not registered with rate limiter")

            limits = state.config
            now = time.monotonic()

            # Check 429 cooldown
            if state.last_429_at:
                since_429 = now - state.last_429_at
                if since_429 < limits.cooldown_after_429:
                    remaining = limits.cooldown_after_429 - since_429
                    raise RateLimitExceededError(
                        f"Platform '{platform}' is in cooldown after 429. "
                        f"Try again in {remaining:.0f}s."
                    )
                # Cooldown has elapsed; forget the 429.
                state.last_429_at = 0.0

            # Check burst limit
            if state.active_requests >= limits.burst_size:
                raise RateLimitExceededError(
                    f"Platform '{platform}' burst limit reached "
                    f"({limits.burst_size} concurrent). Wait for a request to finish."
                )

            # Check RPM: keep only the calls from the last 60 seconds
            window_start = now - 60.0
            recent = [stamp for stamp in state.call_timestamps if stamp > window_start]
            state.call_timestamps = recent

            if len(recent) >= limits.requests_per_minute:
                wait_time = 60.0 - (now - recent[0])
                raise RateLimitExceededError(
                    f"Platform '{platform}' rate limit: {limits.requests_per_minute}/min. "
                    f"Try again in {wait_time:.0f}s."
                )

            # All clear: record the call
            recent.append(now)
            state.active_requests += 1
            self._total_calls += 1

    async def _release(self, platform: str) -> None:
        """Mark one call to ``platform`` as finished; unknown platforms are ignored."""
        async with self._lock:
            state = self._platforms.get(platform)
            if state:
                state.active_requests = max(0, state.active_requests - 1)

    def record_429(self, platform: str) -> None:
        """Call this when an API returns 429 to trigger cooldown."""
        state = self._platforms.get(platform)
        if state:
            state.last_429_at = time.monotonic()

    def get_stats(self) -> dict:
        """Return current usage stats for logging/reporting."""
        per_platform: dict = {}
        report: dict = {
            "total_calls": self._total_calls,
            "remaining_calls": self.remaining_calls,
            "platforms": per_platform,
        }
        for name, state in self._platforms.items():
            now = time.monotonic()
            window_start = now - 60.0
            recent_count = sum(1 for stamp in state.call_timestamps if stamp > window_start)
            cooling_down = bool(
                state.last_429_at
                and (now - state.last_429_at) < state.config.cooldown_after_429
            )
            per_platform[name] = {
                "calls_last_60s": recent_count,
                "active_requests": state.active_requests,
                "rpm_limit": state.config.requests_per_minute,
                "in_cooldown": cooling_down,
            }
        return report
|
||||||
|
|
||||||
|
|
||||||
|
class _AcquireContext:
|
||||||
|
"""Async context manager for rate-limited API calls."""
|
||||||
|
|
||||||
|
def __init__(self, limiter: RateLimiter, platform: str):
|
||||||
|
self._limiter = limiter
|
||||||
|
self._platform = platform
|
||||||
|
|
||||||
|
async def __aenter__(self) -> None:
|
||||||
|
await self._limiter._acquire(self._platform)
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc_info) -> None:
|
||||||
|
# Check if the call got a 429
|
||||||
|
if exc_info[0] is not None:
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
exc = exc_info[1]
|
||||||
|
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429:
|
||||||
|
self._limiter.record_429(self._platform)
|
||||||
|
|
||||||
|
await self._limiter._release(self._platform)
|
||||||
352
agentstuff/sentiment_agent/tools.py
Normal file
352
agentstuff/sentiment_agent/tools.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
"""Custom MCP tools for social media and forum data gathering.
|
||||||
|
|
||||||
|
Each tool wraps an API client, enforces rate limits, runs credibility
|
||||||
|
scoring, and returns MCP-formatted results with bot/disinfo annotations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from claude_agent_sdk import tool, create_sdk_mcp_server
|
||||||
|
|
||||||
|
from sentiment_agent.clients import bluesky, reddit, hackernews
|
||||||
|
from sentiment_agent.config import SafetyConfig
|
||||||
|
from sentiment_agent.credibility import (
|
||||||
|
detect_coordination,
|
||||||
|
filter_and_annotate,
|
||||||
|
score_bluesky_post,
|
||||||
|
score_hackernews_comment,
|
||||||
|
score_hackernews_post,
|
||||||
|
score_reddit_comment,
|
||||||
|
score_reddit_post,
|
||||||
|
)
|
||||||
|
from sentiment_agent.ratelimit import BudgetExhaustedError, RateLimiter
|
||||||
|
|
||||||
|
# Module-level state — initialized by create_social_tools_server()
# Shared rate limiter / call-budget tracker; None until the factory has run.
_limiter: RateLimiter | None = None
# Active safety configuration; None until the factory has run.
_config: SafetyConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_limiter() -> RateLimiter:
    """Return the module-wide rate limiter.

    Raises:
        RuntimeError: the limiter has not been set up yet.
    """
    limiter = _limiter
    if limiter is not None:
        return limiter
    raise RuntimeError("Tools not initialized — call create_social_tools_server() first")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_config() -> SafetyConfig:
    """Return the active safety config, or a default-constructed one if unset."""
    return SafetyConfig() if _config is None else _config
|
||||||
|
|
||||||
|
|
||||||
|
def _text_result(text: str) -> dict:
|
||||||
|
return {"content": [{"type": "text", "text": text}]}
|
||||||
|
|
||||||
|
|
||||||
|
def _error_result(error: str) -> dict:
|
||||||
|
return {"content": [{"type": "text", "text": f"Error: {error}"}], "isError": True}
|
||||||
|
|
||||||
|
|
||||||
|
def _clamp_limit(requested: int) -> int:
    """Bound a caller-supplied result count to a usable range.

    Args:
        requested: Number of results asked for by the tool caller.

    Returns:
        ``requested`` raised to at least 1 and then capped at the configured
        ``max_results_per_call``. The cap always wins, so the result never
        exceeds it.
    """
    cap = _get_config().max_results_per_call
    # Zero or negative limits used to pass straight through to the API
    # clients; raise them to 1 before applying the cap.
    return min(max(1, requested), cap)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_with_stats(
    posts: list[dict],
    stats: dict,
    coordination_warnings: list[str],
    platform: str,
) -> str:
    """Render posts as JSON under a plain-text summary header.

    The header names the platform, gives the authentic / flagged / excluded
    counts from ``stats``, lists any coordination warnings, and reports how
    many API calls remain in the budget. A blank line separates it from the
    JSON body.
    """
    lines = [
        f"Platform: {platform}",
        f"Results: {stats['authentic']} authentic, {stats['flagged']} flagged, "
        f"{stats['excluded']} excluded (of {stats['total']} total)",
    ]
    if coordination_warnings:
        lines += ["--- COORDINATION ALERTS ---", *coordination_warnings, "---"]

    lines.append(f"API budget remaining: {_get_limiter().remaining_calls} calls")

    summary = "\n".join(lines)
    # default=str stringifies any value json cannot encode natively.
    payload = json.dumps(posts, indent=2, default=str)
    return summary + "\n\n" + payload
|
||||||
|
|
||||||
|
|
||||||
|
# --- Bluesky tools ---
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    "search_bluesky",
    "Search Bluesky for posts about a topic. Returns posts with text, author, "
    "engagement metrics, credibility scores, and bot/disinfo flags. "
    "Requires BLUESKY_HANDLE and BLUESKY_APP_PASSWORD env vars.",
    {"query": str, "limit": int, "sort": str},
)
async def search_bluesky(args: dict) -> dict:
    """Tool handler: search Bluesky, then score and filter the hits.

    ``args`` must contain ``query``; ``limit`` defaults to 25 and ``sort`` to
    "top". Every failure is returned as an error result instead of raised.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("bluesky"):
            search_kwargs = {
                "query": args["query"],
                "limit": _clamp_limit(args.get("limit", 25)),
                "sort": args.get("sort", "top"),
            }
            found = await bluesky.search_posts(**search_kwargs)

        if not found:
            return _text_result(f"No Bluesky posts found for: {args['query']}")

        # Coordination is judged on the full batch, before anything is dropped.
        alerts = detect_coordination(found, text_key="text")
        kept, stats = filter_and_annotate(
            found,
            score_bluesky_post,
            min_score=settings.min_credibility_score,
            flag_threshold=settings.flag_bot_threshold,
        )
        report = _format_with_stats(kept, stats, alerts, "Bluesky")
        return _text_result(report)
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"Bluesky search failed: {exc}\n{traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    "get_bluesky_thread",
    "Fetch a Bluesky thread/post and its replies with credibility scoring. "
    "Accepts an at:// URI or https://bsky.app/... URL.",
    {"uri": str, "depth": int},
)
async def get_bluesky_thread(args: dict) -> dict:
    """Tool handler: fetch one Bluesky thread and annotate it with credibility data.

    ``args`` must contain ``uri``; ``depth`` defaults to 6. Replies are
    filtered and annotated; the root post is scored but never removed.
    Every failure is returned as an error result instead of raised.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("bluesky"):
            thread = await bluesky.get_thread(
                uri=args["uri"],
                depth=args.get("depth", 6),
            )

        # Score replies
        replies = thread.get("replies")
        if replies:
            alerts = detect_coordination(replies, text_key="text")
            kept_replies, reply_stats = filter_and_annotate(
                replies,
                score_bluesky_post,
                min_score=settings.min_credibility_score,
                flag_threshold=settings.flag_bot_threshold,
            )
            thread.update(
                {
                    "replies": kept_replies,
                    "_reply_credibility_stats": reply_stats,
                    "_coordination_warnings": alerts,
                }
            )

        # Score root post
        root = thread.get("post")
        if root:
            verdict = score_bluesky_post(root)
            root["_credibility"] = {
                "score": round(verdict.score, 2),
                "label": verdict.label,
                "flags": verdict.flags,
            }

        return _text_result(json.dumps(thread, indent=2, default=str))
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"Bluesky thread fetch failed: {exc}\n{traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Reddit tools ---
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    "search_reddit",
    "Search Reddit for posts about a topic. Returns posts with credibility scores "
    "and bot/disinfo flags. Posts below the credibility threshold are auto-excluded. "
    "Use subreddit='all' for site-wide or specify a subreddit name.",
    {"query": str, "subreddit": str, "sort": str, "time_filter": str, "limit": int},
)
async def search_reddit_tool(args: dict) -> dict:
    """Tool handler: search Reddit posts, then score and filter the hits.

    ``args`` must contain ``query``. Defaults: ``subreddit`` "all", ``sort``
    "relevance", ``time_filter`` "month", ``limit`` 25. Every failure is
    returned as an error result instead of raised.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("reddit"):
            search_kwargs = {
                "query": args["query"],
                "subreddit": args.get("subreddit", "all"),
                "sort": args.get("sort", "relevance"),
                "time_filter": args.get("time_filter", "month"),
                "limit": _clamp_limit(args.get("limit", 25)),
            }
            found = await reddit.search_posts(**search_kwargs)

        if not found:
            return _text_result(f"No Reddit posts found for: {args['query']}")

        # Coordination is checked on post titles.
        alerts = detect_coordination(found, text_key="title")
        kept, stats = filter_and_annotate(
            found,
            score_reddit_post,
            min_score=settings.min_credibility_score,
            flag_threshold=settings.flag_bot_threshold,
        )
        report = _format_with_stats(kept, stats, alerts, "Reddit")
        return _text_result(report)
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"Reddit search failed: {exc}\n{traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    "get_reddit_comments",
    "Fetch comments for a Reddit post with credibility scoring. "
    "Pass the permalink path or full URL.",
    {"permalink": str, "sort": str, "limit": int},
)
async def get_reddit_comments(args: dict) -> dict:
    """Tool handler: fetch a Reddit post's comments, then score and filter them.

    ``args`` must contain ``permalink``; ``sort`` defaults to "top" and
    ``limit`` to 25. Every failure is returned as an error result instead of
    raised.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("reddit"):
            fetch_kwargs = {
                "permalink": args["permalink"],
                "sort": args.get("sort", "top"),
                "limit": _clamp_limit(args.get("limit", 25)),
            }
            found = await reddit.get_post_comments(**fetch_kwargs)

        if not found:
            return _text_result("No comments found for this post.")

        alerts = detect_coordination(found, text_key="body")
        kept, stats = filter_and_annotate(
            found,
            score_reddit_comment,
            min_score=settings.min_credibility_score,
            flag_threshold=settings.flag_bot_threshold,
        )
        report = _format_with_stats(kept, stats, alerts, "Reddit Comments")
        return _text_result(report)
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"Reddit comments fetch failed: {exc}\n{traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Hacker News tools ---
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    "search_hackernews",
    "Search Hacker News for stories with credibility scoring. "
    "No authentication required.",
    {"query": str, "limit": int},
)
async def search_hackernews_tool(args: dict) -> dict:
    """Tool handler: search Hacker News stories, then score and filter the hits.

    ``args`` must contain ``query``; ``limit`` defaults to 25. Every failure
    is returned as an error result instead of raised.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("hackernews"):
            search_kwargs = {
                "query": args["query"],
                "limit": _clamp_limit(args.get("limit", 25)),
            }
            found = await hackernews.search_stories(**search_kwargs)

        if not found:
            return _text_result(f"No HN stories found for: {args['query']}")

        # Coordination is checked on story titles.
        alerts = detect_coordination(found, text_key="title")
        kept, stats = filter_and_annotate(
            found,
            score_hackernews_post,
            min_score=settings.min_credibility_score,
            flag_threshold=settings.flag_bot_threshold,
        )
        report = _format_with_stats(kept, stats, alerts, "Hacker News")
        return _text_result(report)
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"HN search failed: {exc}\n{traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    "search_hackernews_comments",
    "Search Hacker News comments for opinions and discussions with credibility scoring.",
    {"query": str, "limit": int},
)
async def search_hackernews_comments(args: dict) -> dict:
    """Tool handler: search Hacker News comments, then score and filter the hits.

    ``args`` must contain ``query``; ``limit`` defaults to 25. Every failure
    is returned as an error result instead of raised.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("hackernews"):
            search_kwargs = {
                "query": args["query"],
                "limit": _clamp_limit(args.get("limit", 25)),
            }
            found = await hackernews.search_comments(**search_kwargs)

        if not found:
            return _text_result(f"No HN comments found for: {args['query']}")

        alerts = detect_coordination(found, text_key="comment_text")
        kept, stats = filter_and_annotate(
            found,
            score_hackernews_comment,
            min_score=settings.min_credibility_score,
            flag_threshold=settings.flag_bot_threshold,
        )
        report = _format_with_stats(kept, stats, alerts, "HN Comments")
        return _text_result(report)
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"HN comment search failed: {exc}\n{traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Budget status tool ---
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    "get_api_budget_status",
    "Check remaining API call budget, rate limit status, and per-platform stats. "
    "Use this before making more API calls to avoid hitting limits.",
    {},
)
async def get_api_budget_status(args: dict) -> dict:
    """Tool handler: report the limiter's current usage statistics as JSON.

    Args:
        args: Unused; the tool takes no parameters.

    Returns:
        A text result with the JSON-encoded limiter stats, or an error result
        if the tools have not been initialized.
    """
    # The other tool handlers report failures as error results rather than
    # raising; do the same when the limiter has not been created yet.
    try:
        limiter = _get_limiter()
    except RuntimeError as e:
        return _error_result(str(e))

    stats = limiter.get_stats()
    return _text_result(json.dumps(stats, indent=2, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
# --- Server factory ---
|
||||||
|
|
||||||
|
|
||||||
|
def create_social_tools_server(config: SafetyConfig | None = None):
    """Build the "social-tools" MCP server and set up the shared module state.

    Stores the effective config (the given one, or SafetyConfig.from_env()
    when none is supplied) and a fresh RateLimiter in the module globals,
    registers the three platforms with their rate settings, and returns a
    server exposing every tool handler in this module.
    """
    global _limiter, _config

    _config = config or SafetyConfig.from_env()
    _limiter = RateLimiter(max_total_calls=_config.max_total_api_calls)

    # Each platform's limits live on the config as "<platform>_rate".
    for platform in ("bluesky", "reddit", "hackernews"):
        _limiter.register_platform(platform, getattr(_config, f"{platform}_rate"))

    exposed_tools = [
        search_bluesky,
        get_bluesky_thread,
        search_reddit_tool,
        get_reddit_comments,
        search_hackernews_tool,
        search_hackernews_comments,
        get_api_budget_status,
    ]
    return create_sdk_mcp_server("social-tools", tools=exposed_tools)
|
||||||
@@ -1119,7 +1119,7 @@ def _run_encode_job(job_id: str, encode_params: dict) -> None:
|
|||||||
|
|
||||||
filename = encode_result.filename
|
filename = encode_result.filename
|
||||||
if not filename:
|
if not filename:
|
||||||
filename = generate_filename("stego", output_ext)
|
filename = generate_filename(prefix="stego", extension=output_ext.lstrip("."))
|
||||||
elif embed_mode == "dct" and dct_output_format == "jpeg" and filename.endswith(".png"):
|
elif embed_mode == "dct" and dct_output_format == "jpeg" and filename.endswith(".png"):
|
||||||
filename = filename[:-4] + ".jpg"
|
filename = filename[:-4] + ".jpg"
|
||||||
|
|
||||||
@@ -1210,7 +1210,7 @@ def _run_encode_audio_job(job_id: str, encode_params: dict) -> None:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
filename = generate_filename("stego_audio", ".wav")
|
filename = generate_filename(prefix="stego_audio", extension="wav")
|
||||||
file_id = secrets.token_urlsafe(16)
|
file_id = secrets.token_urlsafe(16)
|
||||||
temp_storage.save_temp_file(
|
temp_storage.save_temp_file(
|
||||||
file_id,
|
file_id,
|
||||||
@@ -1273,6 +1273,9 @@ def encode_page():
|
|||||||
|
|
||||||
if carrier_type == "audio":
|
if carrier_type == "audio":
|
||||||
# ========== AUDIO ENCODE PATH (v4.3.0) ==========
|
# ========== AUDIO ENCODE PATH (v4.3.0) ==========
|
||||||
|
# Audio carrier uses a separate form field to avoid name collision
|
||||||
|
carrier = request.files.get("audio_carrier") or carrier
|
||||||
|
|
||||||
if not HAS_AUDIO_SUPPORT:
|
if not HAS_AUDIO_SUPPORT:
|
||||||
return _error_response(
|
return _error_response(
|
||||||
"Audio steganography is not available. Install audio dependencies."
|
"Audio steganography is not available. Install audio dependencies."
|
||||||
@@ -1439,7 +1442,7 @@ def encode_page():
|
|||||||
error_msg = encode_result.error or "Audio encoding failed"
|
error_msg = encode_result.error or "Audio encoding failed"
|
||||||
return _error_response(error_msg)
|
return _error_response(error_msg)
|
||||||
|
|
||||||
filename = generate_filename("stego_audio", ".wav")
|
filename = generate_filename(prefix="stego_audio", extension="wav")
|
||||||
file_id = secrets.token_urlsafe(16)
|
file_id = secrets.token_urlsafe(16)
|
||||||
cleanup_temp_files()
|
cleanup_temp_files()
|
||||||
temp_storage.save_temp_file(
|
temp_storage.save_temp_file(
|
||||||
@@ -1676,7 +1679,7 @@ def encode_page():
|
|||||||
# Use filename from result or generate one
|
# Use filename from result or generate one
|
||||||
filename = encode_result.filename
|
filename = encode_result.filename
|
||||||
if not filename:
|
if not filename:
|
||||||
filename = generate_filename("stego", output_ext)
|
filename = generate_filename(prefix="stego", extension=output_ext.lstrip("."))
|
||||||
elif embed_mode == "dct" and dct_output_format == "jpeg" and filename.endswith(".png"):
|
elif embed_mode == "dct" and dct_output_format == "jpeg" and filename.endswith(".png"):
|
||||||
filename = filename[:-4] + ".jpg"
|
filename = filename[:-4] + ".jpg"
|
||||||
|
|
||||||
@@ -2029,6 +2032,9 @@ def decode_page():
|
|||||||
|
|
||||||
if carrier_type == "audio":
|
if carrier_type == "audio":
|
||||||
# ========== AUDIO DECODE PATH (v4.3.0) ==========
|
# ========== AUDIO DECODE PATH (v4.3.0) ==========
|
||||||
|
# Audio stego uses a separate form field to avoid name collision
|
||||||
|
stego_image = request.files.get("stego_audio") or stego_image
|
||||||
|
|
||||||
if not HAS_AUDIO_SUPPORT:
|
if not HAS_AUDIO_SUPPORT:
|
||||||
flash("Audio steganography is not available.", "error")
|
flash("Audio steganography is not available.", "error")
|
||||||
return render_template("decode.html", has_qrcode_read=HAS_QRCODE_READ)
|
return render_template("decode.html", has_qrcode_read=HAS_QRCODE_READ)
|
||||||
|
|||||||
@@ -974,13 +974,13 @@ const Stegasoo = {
|
|||||||
body: formData,
|
body: formData,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const result = await response.json().catch(() => null);
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error('Failed to start encode');
|
throw new Error((result && result.error) || 'Failed to start encode');
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await response.json();
|
if (result && result.error) {
|
||||||
|
|
||||||
if (result.error) {
|
|
||||||
throw new Error(result.error);
|
throw new Error(result.error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -280,7 +280,7 @@
|
|||||||
<i class="bi bi-file-earmark-music me-1"></i> Stego Audio
|
<i class="bi bi-file-earmark-music me-1"></i> Stego Audio
|
||||||
</label>
|
</label>
|
||||||
<div class="drop-zone pixel-container" id="audioStegoDropZone">
|
<div class="drop-zone pixel-container" id="audioStegoDropZone">
|
||||||
<input type="file" name="stego_image" accept="audio/*" id="audioStegoInput">
|
<input type="file" name="stego_audio" accept="audio/*" id="audioStegoInput">
|
||||||
<div class="drop-zone-label">
|
<div class="drop-zone-label">
|
||||||
<i class="bi bi-music-note-beamed fs-3 d-block mb-2 text-muted"></i>
|
<i class="bi bi-music-note-beamed fs-3 d-block mb-2 text-muted"></i>
|
||||||
<span class="text-muted">Drop audio or click</span>
|
<span class="text-muted">Drop audio or click</span>
|
||||||
|
|||||||
@@ -240,6 +240,11 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Capacity Warning -->
|
||||||
|
<div class="form-text text-danger d-none" id="capacityWarning">
|
||||||
|
<i class="bi bi-exclamation-triangle-fill me-1"></i><span id="capacityWarningText"></span>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Mode & Carrier Type toggles (aligned row) -->
|
<!-- Mode & Carrier Type toggles (aligned row) -->
|
||||||
<div class="row">
|
<div class="row">
|
||||||
<div class="col-md-6">
|
<div class="col-md-6">
|
||||||
@@ -541,6 +546,82 @@ const audioModeGroup = document.getElementById('audioModeGroup');
|
|||||||
const capacityPanel = document.getElementById('capacityPanel');
|
const capacityPanel = document.getElementById('capacityPanel');
|
||||||
const audioCapacityPanel = document.getElementById('audioCapacityPanel');
|
const audioCapacityPanel = document.getElementById('audioCapacityPanel');
|
||||||
|
|
||||||
|
// Capacity tracking for client-side payload size validation
|
||||||
|
let capacityBytes = { dct: 0, lsb: 0, audio_lsb: 0, audio_spread: 0 };
|
||||||
|
|
||||||
|
/**
 * Client-side payload-size validation for the encode form.
 *
 * Measures the current payload (UTF-8 byte length of the text message, or
 * the size of the selected payload file), compares it with the capacity
 * recorded in `capacityBytes` for the checked `embed_mode` radio, and then:
 *   - refreshes the `charPercent` readout for text payloads,
 *   - restores every capacity badge to its default color,
 *   - shows the warning, turns the active badge red and disables the encode
 *     button when the payload is larger than the capacity,
 *   - otherwise hides the warning and enables the encode button.
 *
 * Does nothing if the warning container, its text span, or the encode
 * button is missing from the page.
 */
function checkCapacity() {
    const warningBox = document.getElementById('capacityWarning');
    const warningLabel = document.getElementById('capacityWarningText');
    const submitBtn = document.getElementById('encodeBtn');
    if (!(warningBox && warningLabel && submitBtn)) return;

    // Payload size in bytes: Blob gives the encoded byte length of the text,
    // which can exceed the character count for non-ASCII messages.
    const textSelected = document.getElementById('payloadText')?.checked;
    let payloadBytes = 0;
    if (textSelected) {
        const text = document.getElementById('messageInput')?.value || '';
        payloadBytes = text ? new Blob([text]).size : 0;
    } else {
        const chosenFile = document.getElementById('payloadFileInput')?.files[0];
        payloadBytes = chosenFile ? chosenFile.size : 0;
    }

    // Capacity of the currently checked embed mode (0 = no carrier analysed yet).
    const mode = document.querySelector('input[name="embed_mode"]:checked')?.value || 'lsb';
    const cap = capacityBytes[mode] || 0;

    // Percent readout: use the real capacity once known, else the 250000-byte fallback.
    const percentEl = textSelected ? document.getElementById('charPercent') : null;
    if (percentEl) {
        const denominator = cap > 0 ? cap : 250000;
        percentEl.textContent = Math.round((payloadBytes / denominator) * 100) + '%';
    }

    // One row per embed mode: badge element id and the color class it carries by default.
    const badges = {
        dct: { id: 'dctCapacityBadge', base: 'bg-warning' },
        lsb: { id: 'lsbCapacityBadge', base: 'bg-primary' },
        audio_lsb: { id: 'lsbAudioCapacityBadge', base: 'bg-primary' },
        audio_spread: { id: 'spreadCapacityBadge', base: 'bg-warning' }
    };

    // Put every badge back to its default color before deciding whether one goes red.
    for (const { id, base } of Object.values(badges)) {
        const badgeEl = document.getElementById(id);
        if (badgeEl) {
            badgeEl.classList.remove('bg-danger');
            badgeEl.classList.add(base);
        }
    }

    // Nothing to validate (no carrier capacity or no payload), or the payload fits.
    const nothingToCheck = cap === 0 || payloadBytes === 0;
    if (nothingToCheck || !(payloadBytes > cap)) {
        warningBox.classList.add('d-none');
        submitBtn.disabled = false;
        return;
    }

    // Payload exceeds capacity: red badge, explanatory warning, block submission.
    const activeEntry = badges[mode];
    const activeBadge = activeEntry ? document.getElementById(activeEntry.id) : null;
    if (activeBadge) {
        activeBadge.classList.remove('bg-primary', 'bg-warning');
        activeBadge.classList.add('bg-danger');
    }
    const needed = (payloadBytes / 1024).toFixed(1);
    const available = (cap / 1024).toFixed(1);
    warningLabel.textContent = `Payload too large: ${needed} KB needed, only ${available} KB capacity in ${mode.replace('_', ' ').toUpperCase()} mode`;
    warningBox.classList.remove('d-none');
    submitBtn.disabled = true;
}
|
||||||
|
|
||||||
carrierTypeRadios.forEach(radio => {
|
carrierTypeRadios.forEach(radio => {
|
||||||
radio.addEventListener('change', function() {
|
radio.addEventListener('change', function() {
|
||||||
const isAudio = this.value === 'audio';
|
const isAudio = this.value === 'audio';
|
||||||
@@ -560,9 +641,17 @@ carrierTypeRadios.forEach(radio => {
|
|||||||
if (imageModeGroup) imageModeGroup.classList.toggle('d-none', isAudio);
|
if (imageModeGroup) imageModeGroup.classList.toggle('d-none', isAudio);
|
||||||
if (audioModeGroup) audioModeGroup.classList.toggle('d-none', !isAudio);
|
if (audioModeGroup) audioModeGroup.classList.toggle('d-none', !isAudio);
|
||||||
|
|
||||||
// Toggle capacity panels
|
// Toggle capacity panels and reset capacity values
|
||||||
if (capacityPanel) capacityPanel.classList.add('d-none');
|
if (capacityPanel) capacityPanel.classList.add('d-none');
|
||||||
if (audioCapacityPanel) audioCapacityPanel.classList.add('d-none');
|
if (audioCapacityPanel) audioCapacityPanel.classList.add('d-none');
|
||||||
|
if (isAudio) {
|
||||||
|
capacityBytes.dct = 0;
|
||||||
|
capacityBytes.lsb = 0;
|
||||||
|
} else {
|
||||||
|
capacityBytes.audio_lsb = 0;
|
||||||
|
capacityBytes.audio_spread = 0;
|
||||||
|
}
|
||||||
|
checkCapacity();
|
||||||
|
|
||||||
// Select default mode for the active type and update hint
|
// Select default mode for the active type and update hint
|
||||||
if (isAudio) {
|
if (isAudio) {
|
||||||
@@ -621,7 +710,10 @@ audioCarrierInput?.addEventListener('change', function() {
|
|||||||
document.getElementById('audioInfo').textContent = info;
|
document.getElementById('audioInfo').textContent = info;
|
||||||
document.getElementById('lsbAudioCapacityBadge').textContent = `LSB: ${(data.lsb_capacity / 1024).toFixed(1)} KB`;
|
document.getElementById('lsbAudioCapacityBadge').textContent = `LSB: ${(data.lsb_capacity / 1024).toFixed(1)} KB`;
|
||||||
document.getElementById('spreadCapacityBadge').textContent = `Spread: ${(data.spread_capacity / 1024).toFixed(1)} KB`;
|
document.getElementById('spreadCapacityBadge').textContent = `Spread: ${(data.spread_capacity / 1024).toFixed(1)} KB`;
|
||||||
|
capacityBytes.audio_lsb = data.lsb_capacity;
|
||||||
|
capacityBytes.audio_spread = data.spread_capacity;
|
||||||
document.getElementById('audioCapacityPanel')?.classList.remove('d-none');
|
document.getElementById('audioCapacityPanel')?.classList.remove('d-none');
|
||||||
|
checkCapacity();
|
||||||
if (data.duration) {
|
if (data.duration) {
|
||||||
document.getElementById('audioCarrierDuration').textContent = data.duration + 's duration';
|
document.getElementById('audioCarrierDuration').textContent = data.duration + 's duration';
|
||||||
}
|
}
|
||||||
@@ -763,6 +855,7 @@ function updatePayloadSection() {
|
|||||||
payloadFileInput.setAttribute('required', '');
|
payloadFileInput.setAttribute('required', '');
|
||||||
}
|
}
|
||||||
updatePayloadSummary();
|
updatePayloadSummary();
|
||||||
|
checkCapacity();
|
||||||
}
|
}
|
||||||
|
|
||||||
payloadTextRadio?.addEventListener('change', updatePayloadSection);
|
payloadTextRadio?.addEventListener('change', updatePayloadSection);
|
||||||
@@ -786,6 +879,7 @@ payloadFileInput?.addEventListener('change', function() {
|
|||||||
} else {
|
} else {
|
||||||
fileInfo?.classList.add('d-none');
|
fileInfo?.classList.add('d-none');
|
||||||
}
|
}
|
||||||
|
checkCapacity();
|
||||||
});
|
});
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@@ -795,7 +889,7 @@ payloadFileInput?.addEventListener('change', function() {
|
|||||||
messageInput?.addEventListener('input', function() {
|
messageInput?.addEventListener('input', function() {
|
||||||
const count = this.value.length;
|
const count = this.value.length;
|
||||||
document.getElementById('charCount').textContent = count.toLocaleString();
|
document.getElementById('charCount').textContent = count.toLocaleString();
|
||||||
document.getElementById('charPercent').textContent = Math.round((count / 250000) * 100) + '%';
|
checkCapacity();
|
||||||
});
|
});
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@@ -814,7 +908,10 @@ carrierInput?.addEventListener('change', function() {
|
|||||||
document.getElementById('carrierDimensions').textContent = `${data.width} x ${data.height}`;
|
document.getElementById('carrierDimensions').textContent = `${data.width} x ${data.height}`;
|
||||||
document.getElementById('lsbCapacityBadge').textContent = `LSB: ${data.lsb.capacity_kb} KB`;
|
document.getElementById('lsbCapacityBadge').textContent = `LSB: ${data.lsb.capacity_kb} KB`;
|
||||||
document.getElementById('dctCapacityBadge').textContent = `DCT: ${data.dct.capacity_kb} KB`;
|
document.getElementById('dctCapacityBadge').textContent = `DCT: ${data.dct.capacity_kb} KB`;
|
||||||
|
capacityBytes.lsb = Math.round(data.lsb.capacity_kb * 1024);
|
||||||
|
capacityBytes.dct = Math.round(data.dct.capacity_kb * 1024);
|
||||||
document.getElementById('capacityPanel')?.classList.remove('d-none');
|
document.getElementById('capacityPanel')?.classList.remove('d-none');
|
||||||
|
checkCapacity();
|
||||||
}).catch(() => {});
|
}).catch(() => {});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -859,7 +956,7 @@ function updateOutputOptions(mode) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modeRadios.forEach(radio => {
|
modeRadios.forEach(radio => {
|
||||||
radio.addEventListener('change', () => updateOutputOptions(radio.value));
|
radio.addEventListener('change', () => { updateOutputOptions(radio.value); checkCapacity(); });
|
||||||
});
|
});
|
||||||
|
|
||||||
// Initialize output options based on initial mode
|
// Initialize output options based on initial mode
|
||||||
|
|||||||
@@ -264,18 +264,26 @@ def embed_in_audio_lsb(
|
|||||||
debug.validate(len(sample_key) == 32, f"Sample key must be 32 bytes, got {len(sample_key)}")
|
debug.validate(len(sample_key) == 32, f"Sample key must be 32 bytes, got {len(sample_key)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Read carrier audio
|
# 1. Read carrier audio as float64 (handles all subtypes correctly)
|
||||||
samples, samplerate = sf.read(io.BytesIO(carrier_audio), dtype="int16", always_2d=True)
|
buf = io.BytesIO(carrier_audio)
|
||||||
# samples shape: (num_frames, channels)
|
float_samples, samplerate = sf.read(buf, dtype="float64", always_2d=True)
|
||||||
original_shape = samples.shape
|
original_shape = float_samples.shape
|
||||||
channels = original_shape[1]
|
channels = original_shape[1]
|
||||||
duration = original_shape[0] / samplerate
|
duration = original_shape[0] / samplerate
|
||||||
|
|
||||||
|
# Detect original subtype for output
|
||||||
|
buf.seek(0)
|
||||||
|
carrier_info = sf.info(buf)
|
||||||
|
output_subtype = carrier_info.subtype or "PCM_16"
|
||||||
|
|
||||||
debug.print(
|
debug.print(
|
||||||
f"Carrier audio: {samplerate} Hz, {channels} ch, "
|
f"Carrier audio: {samplerate} Hz, {channels} ch, "
|
||||||
f"{original_shape[0]} frames, {duration:.2f}s"
|
f"{original_shape[0]} frames, {duration:.2f}s, subtype={output_subtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Convert float64 → int16 for LSB manipulation (32768 matches libsndfile normalization)
|
||||||
|
samples = np.clip(float_samples * 32768.0, -32768, 32767).astype(np.int16)
|
||||||
|
|
||||||
# Flatten to 1D for embedding
|
# Flatten to 1D for embedding
|
||||||
flat_samples = samples.flatten().copy()
|
flat_samples = samples.flatten().copy()
|
||||||
num_samples = len(flat_samples)
|
num_samples = len(flat_samples)
|
||||||
@@ -347,7 +355,9 @@ def embed_in_audio_lsb(
|
|||||||
|
|
||||||
debug.print(f"Modified {modified_count} samples (out of {samples_needed} selected)")
|
debug.print(f"Modified {modified_count} samples (out of {samples_needed} selected)")
|
||||||
|
|
||||||
# 7. Reshape and write back as WAV
|
# 7. Reshape and write back as PCM_16 WAV
|
||||||
|
# LSB steganography requires integer samples — writing as FLOAT/DOUBLE
|
||||||
|
# destroys LSBs due to float32 precision loss (33k/65k values fail round-trip).
|
||||||
stego_samples = flat_samples.reshape(original_shape)
|
stego_samples = flat_samples.reshape(original_shape)
|
||||||
|
|
||||||
output_buf = io.BytesIO()
|
output_buf = io.BytesIO()
|
||||||
@@ -412,7 +422,7 @@ def extract_from_audio_lsb(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Read audio
|
# 1. Read audio as int16 directly (stego output is always PCM_16)
|
||||||
samples, samplerate = sf.read(io.BytesIO(audio_data), dtype="int16", always_2d=True)
|
samples, samplerate = sf.read(io.BytesIO(audio_data), dtype="int16", always_2d=True)
|
||||||
flat_samples = samples.flatten()
|
flat_samples = samples.flatten()
|
||||||
num_samples = len(flat_samples)
|
num_samples = len(flat_samples)
|
||||||
|
|||||||
@@ -401,6 +401,12 @@ ALLOWED_AUDIO_EXTENSIONS = {"wav", "flac", "mp3", "ogg", "opus", "aac", "m4a", "
|
|||||||
# Spread spectrum parameters
|
# Spread spectrum parameters
|
||||||
AUDIO_SS_CHIP_LENGTH = 1024 # Samples per chip (spreading factor) — legacy/default
|
AUDIO_SS_CHIP_LENGTH = 1024 # Samples per chip (spreading factor) — legacy/default
|
||||||
AUDIO_SS_AMPLITUDE = 0.05 # Per-sample embedding strength (~-26dB, masked by audio)
|
AUDIO_SS_AMPLITUDE = 0.05 # Per-sample embedding strength (~-26dB, masked by audio)
|
||||||
|
|
||||||
|
# Adaptive amplitude: embed at a fixed ratio below the carrier's RMS level.
|
||||||
|
# Keeps noise inaudible under content while ensuring extraction reliability.
|
||||||
|
AUDIO_SS_AMPLITUDE_RATIO = 0.25 # Fraction of carrier RMS (≈ -12 dB below signal)
|
||||||
|
AUDIO_SS_AMPLITUDE_MIN = 0.001 # Floor: ensures correlation ≥ 0.256 at chip=256
|
||||||
|
AUDIO_SS_AMPLITUDE_MAX = 0.05 # Ceiling: never exceed original fixed amplitude
|
||||||
AUDIO_SS_RS_NSYM = 32 # Reed-Solomon parity symbols
|
AUDIO_SS_RS_NSYM = 32 # Reed-Solomon parity symbols
|
||||||
|
|
||||||
# Spread spectrum v2: per-channel hybrid embedding (v4.4.0)
|
# Spread spectrum v2: per-channel hybrid embedding (v4.4.0)
|
||||||
@@ -424,6 +430,14 @@ AUDIO_SS_CHIP_TIER_NAMES = {
|
|||||||
AUDIO_LFE_CHANNEL_INDEX = 3 # Standard WAV/WAVEFORMATEXTENSIBLE ordering
|
AUDIO_LFE_CHANNEL_INDEX = 3 # Standard WAV/WAVEFORMATEXTENSIBLE ordering
|
||||||
AUDIO_LFE_MIN_CHANNELS = 6 # Only skip LFE for 5.1+ layouts
|
AUDIO_LFE_MIN_CHANNELS = 6 # Only skip LFE for 5.1+ layouts
|
||||||
|
|
||||||
|
# Compact padding for audio encryption (limited carrier capacity)
|
||||||
|
AUDIO_PAD_MIN = 8 # Minimum random padding bytes (vs 64 for images)
|
||||||
|
AUDIO_PAD_RANGE = 32 # Random padding range (vs 256 for images)
|
||||||
|
AUDIO_PAD_ALIGN = 32 # Alignment boundary (vs 256 for images)
|
||||||
|
|
||||||
|
# Lossless audio formats (safe for low chip tier)
|
||||||
|
LOSSLESS_AUDIO_FORMATS = {"wav", "flac", "aiff"}
|
||||||
|
|
||||||
# Echo hiding parameters
|
# Echo hiding parameters
|
||||||
AUDIO_ECHO_DELAY_0 = 50 # Echo delay for bit 0 (samples at 44.1kHz ~ 1.1ms)
|
AUDIO_ECHO_DELAY_0 = 50 # Echo delay for bit 0 (samples at 44.1kHz ~ 1.1ms)
|
||||||
AUDIO_ECHO_DELAY_1 = 100 # Echo delay for bit 1 (samples at 44.1kHz ~ 2.3ms)
|
AUDIO_ECHO_DELAY_1 = 100 # Echo delay for bit 1 (samples at 44.1kHz ~ 2.3ms)
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ from .constants import (
|
|||||||
ARGON2_MEMORY_COST,
|
ARGON2_MEMORY_COST,
|
||||||
ARGON2_PARALLELISM,
|
ARGON2_PARALLELISM,
|
||||||
ARGON2_TIME_COST,
|
ARGON2_TIME_COST,
|
||||||
|
AUDIO_PAD_ALIGN,
|
||||||
|
AUDIO_PAD_MIN,
|
||||||
|
AUDIO_PAD_RANGE,
|
||||||
FORMAT_VERSION,
|
FORMAT_VERSION,
|
||||||
FORMAT_VERSION_LEGACY,
|
FORMAT_VERSION_LEGACY,
|
||||||
HKDF_INFO_ENCRYPT,
|
HKDF_INFO_ENCRYPT,
|
||||||
@@ -463,6 +466,7 @@ def encrypt_message(
|
|||||||
pin: str = "",
|
pin: str = "",
|
||||||
rsa_key_data: bytes | None = None,
|
rsa_key_data: bytes | None = None,
|
||||||
channel_key: str | bool | None = None,
|
channel_key: str | bool | None = None,
|
||||||
|
compact: bool = False,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""
|
"""
|
||||||
Encrypt message or file using AES-256-GCM.
|
Encrypt message or file using AES-256-GCM.
|
||||||
@@ -527,8 +531,17 @@ def encrypt_message(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Random padding to hide message length
|
# Random padding to hide message length
|
||||||
padding_len = secrets.randbelow(256) + 64
|
# Compact mode uses smaller padding for capacity-limited carriers (audio)
|
||||||
padded_len = ((len(packed_payload) + padding_len + 255) // 256) * 256
|
if compact:
|
||||||
|
pad_min = AUDIO_PAD_MIN
|
||||||
|
pad_range = AUDIO_PAD_RANGE
|
||||||
|
pad_align = AUDIO_PAD_ALIGN
|
||||||
|
else:
|
||||||
|
pad_min = 64
|
||||||
|
pad_range = 256
|
||||||
|
pad_align = 256
|
||||||
|
padding_len = secrets.randbelow(pad_range) + pad_min
|
||||||
|
padded_len = ((len(packed_payload) + padding_len + pad_align - 1) // pad_align) * pad_align
|
||||||
padding_needed = padded_len - len(packed_payload)
|
padding_needed = padded_len - len(packed_payload)
|
||||||
padding = secrets.token_bytes(padding_needed - 4) + struct.pack(">I", len(packed_payload))
|
padding = secrets.token_bytes(padding_needed - 4) + struct.pack(">I", len(packed_payload))
|
||||||
padded_message = packed_payload + padding
|
padded_message = packed_payload + padding
|
||||||
|
|||||||
@@ -354,9 +354,10 @@ def encode_audio(
|
|||||||
debug.print(f"Transcoding {audio_format} to WAV for embedding")
|
debug.print(f"Transcoding {audio_format} to WAV for embedding")
|
||||||
carrier_audio = transcode_to_wav(carrier_audio)
|
carrier_audio = transcode_to_wav(carrier_audio)
|
||||||
|
|
||||||
# Encrypt message
|
# Encrypt message (compact padding for audio's limited capacity)
|
||||||
encrypted = encrypt_message(
|
encrypted = encrypt_message(
|
||||||
message, reference_photo, passphrase, pin, rsa_key_data, channel_key
|
message, reference_photo, passphrase, pin, rsa_key_data, channel_key,
|
||||||
|
compact=True,
|
||||||
)
|
)
|
||||||
debug.print(f"Encrypted payload: {len(encrypted)} bytes")
|
debug.print(f"Encrypted payload: {len(encrypted)} bytes")
|
||||||
|
|
||||||
@@ -371,10 +372,21 @@ def encode_audio(
|
|||||||
encrypted, carrier_audio, pixel_key, progress_file=progress_file
|
encrypted, carrier_audio, pixel_key, progress_file=progress_file
|
||||||
)
|
)
|
||||||
elif embed_mode == EMBED_MODE_AUDIO_SPREAD:
|
elif embed_mode == EMBED_MODE_AUDIO_SPREAD:
|
||||||
from .constants import AUDIO_SS_DEFAULT_CHIP_TIER
|
from .constants import (
|
||||||
|
AUDIO_SS_CHIP_TIER_LOSSLESS,
|
||||||
|
AUDIO_SS_DEFAULT_CHIP_TIER,
|
||||||
|
LOSSLESS_AUDIO_FORMATS,
|
||||||
|
)
|
||||||
from .spread_steganography import embed_in_audio_spread
|
from .spread_steganography import embed_in_audio_spread
|
||||||
|
|
||||||
tier = chip_tier if chip_tier is not None else AUDIO_SS_DEFAULT_CHIP_TIER
|
if chip_tier is not None:
|
||||||
|
tier = chip_tier
|
||||||
|
elif audio_format in LOSSLESS_AUDIO_FORMATS:
|
||||||
|
tier = AUDIO_SS_CHIP_TIER_LOSSLESS
|
||||||
|
debug.print(f"Auto-selected chip tier 0 (lossless) for {audio_format} carrier")
|
||||||
|
else:
|
||||||
|
tier = AUDIO_SS_DEFAULT_CHIP_TIER
|
||||||
|
debug.print(f"Auto-selected chip tier {tier} (lossy) for {audio_format} carrier")
|
||||||
stego_audio, stats = embed_in_audio_spread(
|
stego_audio, stats = embed_in_audio_spread(
|
||||||
encrypted, carrier_audio, pixel_key, chip_tier=tier, progress_file=progress_file
|
encrypted, carrier_audio, pixel_key, chip_tier=tier, progress_file=progress_file
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ from .constants import (
|
|||||||
AUDIO_LFE_MIN_CHANNELS,
|
AUDIO_LFE_MIN_CHANNELS,
|
||||||
AUDIO_MAGIC_SPREAD,
|
AUDIO_MAGIC_SPREAD,
|
||||||
AUDIO_SS_AMPLITUDE,
|
AUDIO_SS_AMPLITUDE,
|
||||||
|
AUDIO_SS_AMPLITUDE_MAX,
|
||||||
|
AUDIO_SS_AMPLITUDE_MIN,
|
||||||
|
AUDIO_SS_AMPLITUDE_RATIO,
|
||||||
AUDIO_SS_CHIP_LENGTH,
|
AUDIO_SS_CHIP_LENGTH,
|
||||||
AUDIO_SS_CHIP_LENGTHS,
|
AUDIO_SS_CHIP_LENGTHS,
|
||||||
AUDIO_SS_DEFAULT_CHIP_TIER,
|
AUDIO_SS_DEFAULT_CHIP_TIER,
|
||||||
@@ -81,6 +84,21 @@ except ImportError:
|
|||||||
ReedSolomonError = None # type: ignore[assignment,misc]
|
ReedSolomonError = None # type: ignore[assignment,misc]
|
||||||
|
|
||||||
|
|
||||||
|
def _adaptive_amplitude(samples: np.ndarray) -> float:
    """
    Return the embedding amplitude for a carrier, scaled to its signal level.

    The root-mean-square of ``samples`` is multiplied by
    AUDIO_SS_AMPLITUDE_RATIO and the product is clamped to the range
    [AUDIO_SS_AMPLITUDE_MIN, AUDIO_SS_AMPLITUDE_MAX].

    Args:
        samples: Carrier samples; the RMS is taken over every element
            of the array.

    Returns:
        The clamped amplitude as a Python float.
    """
    # NOTE(review): an empty array makes np.mean yield NaN, which the
    # min()/max() argument order below resolves to AUDIO_SS_AMPLITUDE_MAX —
    # confirm callers never pass an empty carrier.
    mean_square = np.mean(np.square(samples))
    carrier_rms = float(np.sqrt(mean_square))
    scaled = carrier_rms * AUDIO_SS_AMPLITUDE_RATIO

    # Cap at the ceiling first, then lift to the floor.
    capped = min(AUDIO_SS_AMPLITUDE_MAX, scaled)
    return max(AUDIO_SS_AMPLITUDE_MIN, capped)
|
||||||
|
|
||||||
|
|
||||||
# Header sizes
|
# Header sizes
|
||||||
_V0_HEADER_SIZE = 16 # Legacy: 4B magic + 3x4B length
|
_V0_HEADER_SIZE = 16 # Legacy: 4B magic + 3x4B length
|
||||||
_V2_HEADER_SIZE = 20 # v2: 4B magic + 1B ver + 1B tier + 1B nch + 1B flags + 3x4B length
|
_V2_HEADER_SIZE = 20 # v2: 4B magic + 1B ver + 1B tier + 1B nch + 1B flags + 3x4B length
|
||||||
@@ -687,9 +705,13 @@ def embed_in_audio_spread(
|
|||||||
lfe_skipped = len(embed_ch) < channels
|
lfe_skipped = len(embed_ch) < channels
|
||||||
chip_length = AUDIO_SS_CHIP_LENGTHS.get(chip_tier, AUDIO_SS_CHIP_LENGTH)
|
chip_length = AUDIO_SS_CHIP_LENGTHS.get(chip_tier, AUDIO_SS_CHIP_LENGTH)
|
||||||
|
|
||||||
|
# Compute adaptive amplitude scaled to carrier signal level
|
||||||
|
amplitude = _adaptive_amplitude(samples)
|
||||||
|
|
||||||
debug.print(
|
debug.print(
|
||||||
f"Carrier: {sample_rate} Hz, {channels} ch ({num_embed_ch} embeddable), "
|
f"Carrier: {sample_rate} Hz, {channels} ch ({num_embed_ch} embeddable), "
|
||||||
f"{num_frames} frames, {duration:.2f}s, chip={chip_length}"
|
f"{num_frames} frames, {duration:.2f}s, chip={chip_length}, "
|
||||||
|
f"amplitude={amplitude:.6f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. RS-encode the payload
|
# 3. RS-encode the payload
|
||||||
@@ -709,7 +731,7 @@ def embed_in_audio_spread(
|
|||||||
samples[:, embed_ch[0]],
|
samples[:, embed_ch[0]],
|
||||||
header_bits,
|
header_bits,
|
||||||
seed,
|
seed,
|
||||||
AUDIO_SS_AMPLITUDE,
|
amplitude,
|
||||||
_HEADER_CHIP_LENGTH,
|
_HEADER_CHIP_LENGTH,
|
||||||
channel_index=0,
|
channel_index=0,
|
||||||
offset=0,
|
offset=0,
|
||||||
@@ -752,7 +774,7 @@ def embed_in_audio_spread(
|
|||||||
samples[:, ch],
|
samples[:, ch],
|
||||||
bits_for_ch,
|
bits_for_ch,
|
||||||
seed,
|
seed,
|
||||||
AUDIO_SS_AMPLITUDE,
|
amplitude,
|
||||||
chip_length,
|
chip_length,
|
||||||
channel_index=ch,
|
channel_index=ch,
|
||||||
offset=payload_offset,
|
offset=payload_offset,
|
||||||
|
|||||||
BIN
test_data/stupid_elitist_speech.wav
Normal file
BIN
test_data/stupid_elitist_speech.wav
Normal file
Binary file not shown.
Reference in New Issue
Block a user