Minor fixes
This commit is contained in:
30
agentstuff/pyproject.toml
Normal file
30
agentstuff/pyproject.toml
Normal file
@@ -0,0 +1,30 @@
|
||||
[project]
|
||||
name = "sentiment-agent"
|
||||
version = "0.1.0"
|
||||
description = "AI agent for gathering data and performing sentiment analysis"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"claude-agent-sdk",
|
||||
"anyio",
|
||||
"httpx",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
"ruff",
|
||||
"mypy",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
sentiment-agent = "sentiment_agent.main:main"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
ignore_missing_imports = true
|
||||
3
agentstuff/sentiment_agent/__init__.py
Normal file
3
agentstuff/sentiment_agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Sentiment analysis agent powered by Claude Agent SDK."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
115
agentstuff/sentiment_agent/agent.py
Normal file
115
agentstuff/sentiment_agent/agent.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Core sentiment analysis agent using Claude Agent SDK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeAgentOptions,
|
||||
ClaudeSDKClient,
|
||||
ResultMessage,
|
||||
TextBlock,
|
||||
)
|
||||
|
||||
from sentiment_agent.config import SafetyConfig
|
||||
from sentiment_agent.tools import create_social_tools_server
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
You are a sentiment analysis agent. Your job is to gather data from multiple \
|
||||
platforms and produce a structured, evidence-based sentiment report.
|
||||
|
||||
## Rules — you MUST follow these
|
||||
|
||||
1. **Budget awareness.** You have a limited API call budget. Call \
|
||||
`get_api_budget_status` before starting and after every few tool calls. \
|
||||
Stop gathering data when you have <5 calls remaining and begin your analysis.
|
||||
|
||||
2. **Credibility first.** Every tool result includes credibility scores and \
|
||||
bot/disinfo flags. You MUST:
|
||||
- NEVER quote or cite posts marked `likely_inauthentic` (score < 0.3).
|
||||
- Flag posts marked `suspicious` (score 0.3–0.5) with a warning when citing them.
|
||||
- Give more weight to `likely_authentic` posts (score ≥ 0.7).
|
||||
- If coordination warnings appear (copy-paste campaigns, burst posting), \
|
||||
call them out prominently in your report.
|
||||
|
||||
3. **Platform diversity.** Gather from at least 2 different platforms before \
|
||||
analyzing. Do not over-index on a single source.
|
||||
|
||||
4. **No fabrication.** Only report on data you actually retrieved. If a tool \
|
||||
call fails or returns no results, say so — do not invent data.
|
||||
|
||||
5. **Structured output.** Your final report MUST include these sections:
|
||||
- **Data Quality Summary**: platforms queried, posts analyzed vs excluded, \
|
||||
coordination warnings
|
||||
- **Overall Sentiment**: score (-1.0 to +1.0) and label \
|
||||
(very negative / negative / mixed / neutral / positive / very positive)
|
||||
- **Platform Breakdown**: sentiment per platform with sample size
|
||||
- **Key Themes**: top 3-5 themes with sentiment direction
|
||||
- **Credibility Concerns**: any bot networks, disinfo patterns, or \
|
||||
coordinated campaigns detected
|
||||
- **Notable Quotes**: 3-5 representative quotes (authentic sources only, \
|
||||
with credibility score noted)
|
||||
- **Confidence Assessment**: how confident you are in the analysis given \
|
||||
data quality and volume
|
||||
|
||||
6. **Scope discipline.** Stay focused on the requested topic. Do not expand \
|
||||
scope, follow tangents, or analyze adjacent topics unless explicitly asked.
|
||||
|
||||
7. **No side effects.** Do not write files, run commands, or take any action \
|
||||
beyond reading data and producing your report.
|
||||
"""
|
||||
|
||||
|
||||
async def run_sentiment_analysis(
|
||||
topic: str,
|
||||
sources: list[str] | None = None,
|
||||
config: SafetyConfig | None = None,
|
||||
) -> str:
|
||||
"""Run the sentiment analysis agent on a given topic.
|
||||
|
||||
Args:
|
||||
topic: The topic or subject to analyze sentiment for.
|
||||
sources: Optional list of URLs or data sources to analyze.
|
||||
config: Safety configuration. Defaults to SafetyConfig.from_env().
|
||||
|
||||
Returns:
|
||||
The agent's sentiment analysis report.
|
||||
"""
|
||||
config = config or SafetyConfig.from_env()
|
||||
|
||||
source_instructions = ""
|
||||
if sources:
|
||||
source_list = "\n".join(f"- {s}" for s in sources)
|
||||
source_instructions = f"\n\nAlso analyze these specific sources:\n{source_list}"
|
||||
|
||||
prompt = (
|
||||
f"Perform a sentiment analysis on the following topic: {topic}\n\n"
|
||||
"Start by calling `get_api_budget_status` to check your budget, then "
|
||||
"gather data from multiple platforms (Reddit, Hacker News, Bluesky if "
|
||||
"configured, and web search). Pay close attention to credibility scores "
|
||||
"and coordination warnings in the results."
|
||||
f"{source_instructions}"
|
||||
)
|
||||
|
||||
social_server = create_social_tools_server(config)
|
||||
|
||||
options = ClaudeAgentOptions(
|
||||
# Only allow read-only tools — no Write/Bash to prevent side effects
|
||||
allowed_tools=["WebSearch", "WebFetch", "Read"],
|
||||
max_turns=config.max_turns,
|
||||
max_budget_usd=config.max_budget_usd,
|
||||
mcp_servers={"social": social_server},
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
await client.query(prompt)
|
||||
async for message in client.receive_response():
|
||||
if isinstance(message, AssistantMessage):
|
||||
for block in message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(block.text, end="", flush=True)
|
||||
if isinstance(message, ResultMessage):
|
||||
result_text = message.result
|
||||
|
||||
return result_text
|
||||
1
agentstuff/sentiment_agent/clients/__init__.py
Normal file
1
agentstuff/sentiment_agent/clients/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API clients for social media and forum data sources."""
|
||||
166
agentstuff/sentiment_agent/clients/bluesky.py
Normal file
166
agentstuff/sentiment_agent/clients/bluesky.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Bluesky client using the AT Protocol API.
|
||||
|
||||
Search requires authentication. Set BLUESKY_HANDLE and BLUESKY_APP_PASSWORD
|
||||
env vars. Create an app password at: https://bsky.app/settings/app-passwords
|
||||
|
||||
Thread fetching works without auth via the public API.
|
||||
"""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
|
||||
BSKY_PUBLIC_API = "https://public.api.bsky.app"
|
||||
BSKY_AUTH_API = "https://bsky.social"
|
||||
|
||||
|
||||
async def _get_session() -> dict | None:
|
||||
"""Authenticate with Bluesky and return session tokens, or None if no creds."""
|
||||
handle = os.environ.get("BLUESKY_HANDLE")
|
||||
app_password = os.environ.get("BLUESKY_APP_PASSWORD")
|
||||
if not handle or not app_password:
|
||||
return None
|
||||
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = await client.post(
|
||||
f"{BSKY_AUTH_API}/xrpc/com.atproto.server.createSession",
|
||||
json={"identifier": handle, "password": app_password},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _format_post(post_view: dict) -> dict:
|
||||
"""Extract relevant fields from an AT Protocol post view."""
|
||||
post = post_view.get("post", post_view)
|
||||
record = post.get("record", {})
|
||||
author = post.get("author", {})
|
||||
return {
|
||||
"text": record.get("text", ""),
|
||||
"author_handle": author.get("handle", ""),
|
||||
"author_display_name": author.get("displayName", ""),
|
||||
"created_at": record.get("createdAt", ""),
|
||||
"like_count": post.get("likeCount", 0),
|
||||
"repost_count": post.get("repostCount", 0),
|
||||
"reply_count": post.get("replyCount", 0),
|
||||
"uri": post.get("uri", ""),
|
||||
"cid": post.get("cid", ""),
|
||||
"url": _uri_to_url(post.get("uri", ""), author.get("handle", "")),
|
||||
}
|
||||
|
||||
|
||||
def _uri_to_url(uri: str, handle: str) -> str:
|
||||
"""Convert an at:// URI to a bsky.app URL."""
|
||||
# at://did:plc:xxx/app.bsky.feed.post/rkey -> https://bsky.app/profile/handle/post/rkey
|
||||
if not uri.startswith("at://"):
|
||||
return ""
|
||||
parts = uri.split("/")
|
||||
if len(parts) >= 5:
|
||||
rkey = parts[-1]
|
||||
return f"https://bsky.app/profile/{handle}/post/{rkey}"
|
||||
return ""
|
||||
|
||||
|
||||
async def search_posts(query: str, limit: int = 25, sort: str = "top") -> list[dict]:
|
||||
"""Search Bluesky for posts matching a query.
|
||||
|
||||
Requires BLUESKY_HANDLE and BLUESKY_APP_PASSWORD env vars.
|
||||
|
||||
Args:
|
||||
query: Search terms.
|
||||
limit: Max results (capped at 100).
|
||||
sort: "top" (most liked) or "latest" (chronological).
|
||||
|
||||
Returns:
|
||||
List of post dicts with: text, author_handle, author_display_name,
|
||||
created_at, like_count, repost_count, reply_count, uri, url.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Bluesky credentials are not configured.
|
||||
"""
|
||||
session = await _get_session()
|
||||
if not session:
|
||||
raise RuntimeError(
|
||||
"Bluesky search requires authentication. "
|
||||
"Set BLUESKY_HANDLE and BLUESKY_APP_PASSWORD environment variables. "
|
||||
"Create an app password at: https://bsky.app/settings/app-passwords"
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = await client.get(
|
||||
f"{BSKY_AUTH_API}/xrpc/app.bsky.feed.searchPosts",
|
||||
params={
|
||||
"q": query,
|
||||
"limit": min(limit, 100),
|
||||
"sort": sort,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {session['accessJwt']}"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
return [_format_post(p) for p in data.get("posts", [])]
|
||||
|
||||
|
||||
async def get_thread(uri: str, depth: int = 6) -> dict:
|
||||
"""Fetch a Bluesky thread by AT URI or bsky.app URL.
|
||||
|
||||
Args:
|
||||
uri: Either an at:// URI or a https://bsky.app/profile/.../post/... URL.
|
||||
depth: How many levels of replies to fetch (max 1000).
|
||||
|
||||
Returns:
|
||||
Dict with "post" (the root post) and "replies" (list of reply post dicts).
|
||||
"""
|
||||
# Convert bsky.app URL to AT URI if needed
|
||||
if uri.startswith("https://bsky.app/"):
|
||||
uri = await _resolve_url_to_uri(uri)
|
||||
|
||||
headers = {}
|
||||
session = await _get_session()
|
||||
if session:
|
||||
headers["Authorization"] = f"Bearer {session['accessJwt']}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = await client.get(
|
||||
f"{BSKY_PUBLIC_API}/xrpc/app.bsky.feed.getPostThread",
|
||||
params={"uri": uri, "depth": min(depth, 1000)},
|
||||
headers=headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
thread = data.get("thread", {})
|
||||
root_post = _format_post(thread) if "post" in thread else {}
|
||||
|
||||
replies = []
|
||||
for reply in thread.get("replies", []):
|
||||
if "post" in reply:
|
||||
replies.append(_format_post(reply))
|
||||
# Include nested replies one level deep
|
||||
for nested in reply.get("replies", []):
|
||||
if "post" in nested:
|
||||
replies.append(_format_post(nested))
|
||||
|
||||
return {"post": root_post, "replies": replies}
|
||||
|
||||
|
||||
async def _resolve_url_to_uri(url: str) -> str:
|
||||
"""Convert a bsky.app URL to an AT URI by resolving the handle."""
|
||||
# https://bsky.app/profile/handle.bsky.social/post/rkey
|
||||
parts = url.rstrip("/").split("/")
|
||||
if len(parts) < 6:
|
||||
raise ValueError(f"Invalid Bluesky URL: {url}")
|
||||
|
||||
handle = parts[4] # profile/{handle}
|
||||
rkey = parts[6] # post/{rkey}
|
||||
|
||||
# Resolve handle to DID
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.get(
|
||||
f"{BSKY_PUBLIC_API}/xrpc/com.atproto.identity.resolveHandle",
|
||||
params={"handle": handle},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
did = resp.json()["did"]
|
||||
|
||||
return f"at://{did}/app.bsky.feed.post/{rkey}"
|
||||
78
agentstuff/sentiment_agent/clients/hackernews.py
Normal file
78
agentstuff/sentiment_agent/clients/hackernews.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Hacker News client using the Algolia HN Search API.
|
||||
|
||||
No authentication required. Docs: https://hn.algolia.com/api
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
HN_API_BASE = "https://hn.algolia.com/api/v1"
|
||||
|
||||
|
||||
async def search_stories(query: str, limit: int = 25) -> list[dict]:
|
||||
"""Search HN for stories matching a query.
|
||||
|
||||
Returns a list of story dicts with: title, url, author, points,
|
||||
num_comments, created_at, objectID, story_text.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = await client.get(
|
||||
f"{HN_API_BASE}/search",
|
||||
params={
|
||||
"query": query,
|
||||
"tags": "story",
|
||||
"hitsPerPage": min(limit, 50),
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
results = []
|
||||
for hit in data.get("hits", []):
|
||||
results.append(
|
||||
{
|
||||
"title": hit.get("title", ""),
|
||||
"url": hit.get("url", ""),
|
||||
"author": hit.get("author", ""),
|
||||
"points": hit.get("points", 0),
|
||||
"num_comments": hit.get("num_comments", 0),
|
||||
"created_at": hit.get("created_at", ""),
|
||||
"object_id": hit.get("objectID", ""),
|
||||
"story_text": hit.get("story_text") or "",
|
||||
"hn_url": f"https://news.ycombinator.com/item?id={hit.get('objectID', '')}",
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def search_comments(query: str, limit: int = 25) -> list[dict]:
|
||||
"""Search HN for comments matching a query.
|
||||
|
||||
Returns a list of comment dicts with: comment_text, author, points,
|
||||
created_at, story_title, story_url.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
resp = await client.get(
|
||||
f"{HN_API_BASE}/search",
|
||||
params={
|
||||
"query": query,
|
||||
"tags": "comment",
|
||||
"hitsPerPage": min(limit, 50),
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
results = []
|
||||
for hit in data.get("hits", []):
|
||||
results.append(
|
||||
{
|
||||
"comment_text": hit.get("comment_text", ""),
|
||||
"author": hit.get("author", ""),
|
||||
"points": hit.get("points", 0),
|
||||
"created_at": hit.get("created_at", ""),
|
||||
"story_title": hit.get("story_title", ""),
|
||||
"story_url": hit.get("story_url", ""),
|
||||
"hn_url": f"https://news.ycombinator.com/item?id={hit.get('objectID', '')}",
|
||||
}
|
||||
)
|
||||
return results
|
||||
117
agentstuff/sentiment_agent/clients/reddit.py
Normal file
117
agentstuff/sentiment_agent/clients/reddit.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Reddit client using the public JSON API.
|
||||
|
||||
No authentication required for read-only search. Reddit requires a descriptive
|
||||
User-Agent header — requests with generic UAs get 429'd.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
REDDIT_BASE = "https://www.reddit.com"
|
||||
USER_AGENT = "sentiment-agent/0.1.0 (research; sentiment analysis tool)"
|
||||
|
||||
|
||||
async def search_posts(
|
||||
query: str,
|
||||
subreddit: str = "all",
|
||||
sort: str = "relevance",
|
||||
time_filter: str = "month",
|
||||
limit: int = 25,
|
||||
) -> list[dict]:
|
||||
"""Search Reddit for posts matching a query.
|
||||
|
||||
Args:
|
||||
query: Search terms.
|
||||
subreddit: Subreddit to search within, or "all" for site-wide.
|
||||
sort: One of "relevance", "hot", "top", "new", "comments".
|
||||
time_filter: One of "hour", "day", "week", "month", "year", "all".
|
||||
limit: Max results (capped at 100 by Reddit).
|
||||
|
||||
Returns:
|
||||
List of post dicts with: title, selftext, author, score,
|
||||
num_comments, subreddit, url, permalink, created_utc.
|
||||
"""
|
||||
url = f"{REDDIT_BASE}/r/{subreddit}/search.json"
|
||||
async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
|
||||
resp = await client.get(
|
||||
url,
|
||||
params={
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"t": time_filter,
|
||||
"limit": min(limit, 100),
|
||||
"restrict_sr": "on" if subreddit != "all" else "off",
|
||||
},
|
||||
headers={"User-Agent": USER_AGENT},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
results = []
|
||||
for child in data.get("data", {}).get("children", []):
|
||||
post = child.get("data", {})
|
||||
results.append(
|
||||
{
|
||||
"title": post.get("title", ""),
|
||||
"selftext": (post.get("selftext") or "")[:2000],
|
||||
"author": post.get("author", "[deleted]"),
|
||||
"score": post.get("score", 0),
|
||||
"upvote_ratio": post.get("upvote_ratio", 0),
|
||||
"num_comments": post.get("num_comments", 0),
|
||||
"subreddit": post.get("subreddit", ""),
|
||||
"url": post.get("url", ""),
|
||||
"permalink": f"https://reddit.com{post.get('permalink', '')}",
|
||||
"created_utc": post.get("created_utc", 0),
|
||||
"is_self": post.get("is_self", False),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def get_post_comments(
|
||||
permalink: str,
|
||||
sort: str = "top",
|
||||
limit: int = 25,
|
||||
) -> list[dict]:
|
||||
"""Fetch top-level comments for a Reddit post.
|
||||
|
||||
Args:
|
||||
permalink: The post's permalink path (e.g., "/r/python/comments/abc123/title/").
|
||||
sort: Comment sort order: "top", "best", "new", "controversial".
|
||||
limit: Max comments to return.
|
||||
|
||||
Returns:
|
||||
List of comment dicts with: body, author, score, created_utc.
|
||||
"""
|
||||
# Strip domain if full URL was passed
|
||||
if permalink.startswith("https://"):
|
||||
permalink = permalink.replace("https://reddit.com", "")
|
||||
permalink = permalink.replace("https://www.reddit.com", "")
|
||||
|
||||
url = f"{REDDIT_BASE}{permalink}.json"
|
||||
async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
|
||||
resp = await client.get(
|
||||
url,
|
||||
params={"sort": sort, "limit": limit},
|
||||
headers={"User-Agent": USER_AGENT},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Reddit returns [post_listing, comments_listing]
|
||||
if not isinstance(data, list) or len(data) < 2:
|
||||
return []
|
||||
|
||||
results = []
|
||||
for child in data[1].get("data", {}).get("children", []):
|
||||
if child.get("kind") != "t1":
|
||||
continue
|
||||
comment = child.get("data", {})
|
||||
results.append(
|
||||
{
|
||||
"body": (comment.get("body") or "")[:2000],
|
||||
"author": comment.get("author", "[deleted]"),
|
||||
"score": comment.get("score", 0),
|
||||
"created_utc": comment.get("created_utc", 0),
|
||||
}
|
||||
)
|
||||
return results
|
||||
70
agentstuff/sentiment_agent/config.py
Normal file
70
agentstuff/sentiment_agent/config.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Configuration and safety limits for the sentiment agent.
|
||||
|
||||
All guardrails are centralized here so they can be tuned from one place
|
||||
or overridden via CLI flags / env vars.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
"""Per-platform rate limiting."""
|
||||
|
||||
requests_per_minute: int = 10
|
||||
burst_size: int = 3 # max concurrent requests
|
||||
cooldown_after_429: float = 30.0 # seconds to wait after a 429
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SafetyConfig:
|
||||
"""Top-level safety rails for the agent."""
|
||||
|
||||
# --- Agent-level limits ---
|
||||
max_turns: int = 20
|
||||
max_budget_usd: float = 0.50 # hard cap on Claude API spend per run
|
||||
max_total_api_calls: int = 50 # across ALL platforms combined
|
||||
max_results_per_call: int = 50 # cap the `limit` param sent to any API
|
||||
|
||||
# --- Per-platform rate limits ---
|
||||
bluesky_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig(
|
||||
requests_per_minute=10, burst_size=2,
|
||||
))
|
||||
reddit_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig(
|
||||
requests_per_minute=10, burst_size=2,
|
||||
))
|
||||
hackernews_rate: RateLimitConfig = field(default_factory=lambda: RateLimitConfig(
|
||||
requests_per_minute=15, burst_size=3, # HN Algolia is more generous
|
||||
))
|
||||
|
||||
# --- Content size limits ---
|
||||
max_post_text_chars: int = 2000 # truncate individual posts beyond this
|
||||
max_total_content_bytes: int = 500_000 # ~500KB total data gathered before agent stops
|
||||
|
||||
# --- Timeout ---
|
||||
api_timeout_seconds: float = 15.0
|
||||
|
||||
# --- Credibility thresholds ---
|
||||
min_credibility_score: float = 0.3 # posts below this are flagged/excluded
|
||||
flag_bot_threshold: float = 0.5 # posts between min and this are flagged but included
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> SafetyConfig:
|
||||
"""Build config with env var overrides.
|
||||
|
||||
Env vars: SENTIMENT_MAX_TURNS, SENTIMENT_MAX_BUDGET_USD,
|
||||
SENTIMENT_MAX_API_CALLS, SENTIMENT_MIN_CREDIBILITY.
|
||||
"""
|
||||
kwargs: dict = {}
|
||||
if v := os.environ.get("SENTIMENT_MAX_TURNS"):
|
||||
kwargs["max_turns"] = int(v)
|
||||
if v := os.environ.get("SENTIMENT_MAX_BUDGET_USD"):
|
||||
kwargs["max_budget_usd"] = float(v)
|
||||
if v := os.environ.get("SENTIMENT_MAX_API_CALLS"):
|
||||
kwargs["max_total_api_calls"] = int(v)
|
||||
if v := os.environ.get("SENTIMENT_MIN_CREDIBILITY"):
|
||||
kwargs["min_credibility_score"] = float(v)
|
||||
return cls(**kwargs)
|
||||
398
agentstuff/sentiment_agent/credibility.py
Normal file
398
agentstuff/sentiment_agent/credibility.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""Credibility scoring and bot/disinfo detection.
|
||||
|
||||
Assigns a 0.0–1.0 credibility score to each post based on heuristic signals.
|
||||
Posts below the configured threshold are excluded or flagged so they don't
|
||||
pollute the sentiment analysis.
|
||||
|
||||
Signals are platform-aware — each platform has different indicators of
|
||||
inauthentic behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
@dataclass
|
||||
class CredibilityResult:
|
||||
"""Credibility assessment for a single post."""
|
||||
|
||||
score: float # 0.0 (likely bot/disinfo) to 1.0 (likely authentic)
|
||||
flags: list[str] = field(default_factory=list) # human-readable reasons
|
||||
is_excluded: bool = False # below min_credibility_score
|
||||
is_flagged: bool = False # between min and flag threshold
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
if self.score >= 0.7:
|
||||
return "likely_authentic"
|
||||
if self.score >= 0.5:
|
||||
return "uncertain"
|
||||
if self.score >= 0.3:
|
||||
return "suspicious"
|
||||
return "likely_inauthentic"
|
||||
|
||||
|
||||
# --- Shared heuristics ---
|
||||
|
||||
# Common bot patterns in text
|
||||
_BOT_TEXT_PATTERNS = [
|
||||
# Crypto/scam spam
|
||||
re.compile(r"(?i)(dm me|check my bio|link in bio|click here|free giveaway)"),
|
||||
re.compile(r"(?i)(join my|subscribe to|follow me for|🔥.*🔥.*🔥)"),
|
||||
# Astroturfing phrases
|
||||
re.compile(r"(?i)(i (just )?(discovered|found|tried) this (amazing|incredible|awesome))"),
|
||||
re.compile(r"(?i)(game.?changer|life.?changing|you won'?t believe)"),
|
||||
# Excessive hashtags (5+)
|
||||
re.compile(r"(#\w+\s*){5,}"),
|
||||
# Walls of emojis (10+ consecutive)
|
||||
re.compile(r"[\U0001F300-\U0001FAFF]{10,}"),
|
||||
# Repetitive characters (spammy emphasis)
|
||||
re.compile(r"(.)\1{9,}"),
|
||||
]
|
||||
|
||||
# Coordinated campaign indicators: identical or near-identical text
|
||||
# This is checked at the batch level, not per-post
|
||||
|
||||
|
||||
def _check_text_patterns(text: str) -> list[str]:
|
||||
"""Check text against common bot/spam patterns."""
|
||||
flags = []
|
||||
for pattern in _BOT_TEXT_PATTERNS:
|
||||
if pattern.search(text):
|
||||
flags.append(f"bot_text_pattern: {pattern.pattern[:60]}")
|
||||
if len(text) < 15:
|
||||
flags.append("very_short_text")
|
||||
return flags
|
||||
|
||||
|
||||
def _engagement_ratio_score(
|
||||
likes: int, reposts: int, replies: int
|
||||
) -> tuple[float, list[str]]:
|
||||
"""Score based on engagement ratios.
|
||||
|
||||
Authentic posts tend to have a mix of likes, replies, and reposts.
|
||||
Bot-amplified posts often have inflated likes with very few replies,
|
||||
or massive repost counts with no discussion.
|
||||
"""
|
||||
flags = []
|
||||
total = likes + reposts + replies
|
||||
|
||||
if total == 0:
|
||||
return 0.5, ["no_engagement"]
|
||||
|
||||
# High repost-to-reply ratio suggests amplification without discussion
|
||||
if reposts > 0 and replies == 0 and reposts > 10:
|
||||
flags.append(f"high_repost_no_replies: {reposts} reposts, 0 replies")
|
||||
return 0.3, flags
|
||||
|
||||
# Extremely high like count with zero replies is suspicious
|
||||
if likes > 100 and replies == 0:
|
||||
flags.append(f"high_likes_no_replies: {likes} likes, 0 replies")
|
||||
return 0.4, flags
|
||||
|
||||
# Normal engagement
|
||||
return min(1.0, 0.5 + (replies / max(total, 1)) * 0.5), flags
|
||||
|
||||
|
||||
# --- Platform-specific scoring ---
|
||||
|
||||
|
||||
def score_bluesky_post(post: dict) -> CredibilityResult:
|
||||
"""Score a Bluesky post for credibility."""
|
||||
score = 1.0
|
||||
flags: list[str] = []
|
||||
|
||||
text = post.get("text", "")
|
||||
handle = post.get("author_handle", "")
|
||||
display_name = post.get("author_display_name", "")
|
||||
likes = post.get("like_count", 0)
|
||||
reposts = post.get("repost_count", 0)
|
||||
replies = post.get("reply_count", 0)
|
||||
|
||||
# Text pattern checks
|
||||
text_flags = _check_text_patterns(text)
|
||||
if text_flags:
|
||||
score -= 0.15 * len(text_flags)
|
||||
flags.extend(text_flags)
|
||||
|
||||
# Handle heuristics
|
||||
# Randomly generated handles (long hex/number strings)
|
||||
if re.match(r"^[a-f0-9]{8,}\.", handle):
|
||||
flags.append(f"random_handle: {handle}")
|
||||
score -= 0.3
|
||||
|
||||
# No display name set
|
||||
if not display_name or display_name == handle:
|
||||
flags.append("no_display_name")
|
||||
score -= 0.1
|
||||
|
||||
# Engagement ratio
|
||||
eng_score, eng_flags = _engagement_ratio_score(likes, reposts, replies)
|
||||
flags.extend(eng_flags)
|
||||
score = score * 0.6 + eng_score * 0.4
|
||||
|
||||
return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags)
|
||||
|
||||
|
||||
def score_reddit_post(post: dict) -> CredibilityResult:
|
||||
"""Score a Reddit post for credibility."""
|
||||
score = 1.0
|
||||
flags: list[str] = []
|
||||
|
||||
text = post.get("selftext", "") or post.get("title", "")
|
||||
author = post.get("author", "")
|
||||
upvote_ratio = post.get("upvote_ratio", 0.5)
|
||||
post_score = post.get("score", 0)
|
||||
num_comments = post.get("num_comments", 0)
|
||||
|
||||
# Text patterns
|
||||
text_flags = _check_text_patterns(text)
|
||||
if text_flags:
|
||||
score -= 0.15 * len(text_flags)
|
||||
flags.extend(text_flags)
|
||||
|
||||
# Deleted author
|
||||
if author in ("[deleted]", "[removed]"):
|
||||
flags.append("deleted_author")
|
||||
score -= 0.2
|
||||
|
||||
# Suspicious username patterns (random alphanumeric + numbers)
|
||||
if re.match(r"^[A-Za-z]+[-_]?\d{4,}$", author):
|
||||
flags.append(f"auto_generated_username: {author}")
|
||||
score -= 0.15
|
||||
|
||||
# Very controversial ratio (lots of up AND down votes)
|
||||
if upvote_ratio < 0.4 and post_score > 0:
|
||||
flags.append(f"highly_controversial: {upvote_ratio:.0%} upvote ratio")
|
||||
score -= 0.1
|
||||
|
||||
# High score but zero comments = potential vote manipulation
|
||||
if post_score > 100 and num_comments == 0:
|
||||
flags.append(f"high_score_no_comments: {post_score} score, 0 comments")
|
||||
score -= 0.2
|
||||
|
||||
# Low-effort cross-post spam: very short title, external link, no selftext
|
||||
if (
|
||||
len(post.get("title", "")) < 20
|
||||
and not post.get("is_self", True)
|
||||
and not post.get("selftext")
|
||||
):
|
||||
flags.append("possible_link_spam")
|
||||
score -= 0.1
|
||||
|
||||
return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags)
|
||||
|
||||
|
||||
def score_reddit_comment(comment: dict) -> CredibilityResult:
|
||||
"""Score a Reddit comment for credibility."""
|
||||
score = 1.0
|
||||
flags: list[str] = []
|
||||
|
||||
body = comment.get("body", "")
|
||||
author = comment.get("author", "")
|
||||
comment_score = comment.get("score", 0)
|
||||
|
||||
text_flags = _check_text_patterns(body)
|
||||
if text_flags:
|
||||
score -= 0.15 * len(text_flags)
|
||||
flags.extend(text_flags)
|
||||
|
||||
if author in ("[deleted]", "[removed]"):
|
||||
flags.append("deleted_author")
|
||||
score -= 0.2
|
||||
|
||||
if re.match(r"^[A-Za-z]+[-_]?\d{4,}$", author):
|
||||
flags.append(f"auto_generated_username: {author}")
|
||||
score -= 0.15
|
||||
|
||||
# Heavily downvoted
|
||||
if comment_score < -5:
|
||||
flags.append(f"heavily_downvoted: {comment_score}")
|
||||
score -= 0.15
|
||||
|
||||
return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags)
|
||||
|
||||
|
||||
def score_hackernews_post(post: dict) -> CredibilityResult:
|
||||
"""Score a HN story for credibility.
|
||||
|
||||
HN is generally higher-signal than social media, but we still check
|
||||
for low-effort submissions and spammy patterns.
|
||||
"""
|
||||
score = 1.0
|
||||
flags: list[str] = []
|
||||
|
||||
title = post.get("title", "")
|
||||
text = post.get("story_text", "") or title
|
||||
points = post.get("points", 0)
|
||||
num_comments = post.get("num_comments", 0)
|
||||
|
||||
text_flags = _check_text_patterns(text)
|
||||
if text_flags:
|
||||
score -= 0.1 * len(text_flags)
|
||||
flags.extend(text_flags)
|
||||
|
||||
# Zero points = the community didn't find it valuable
|
||||
if points == 0:
|
||||
flags.append("zero_points")
|
||||
score -= 0.1
|
||||
|
||||
# HN is generally more credible, start with a bonus
|
||||
score = min(1.0, score + 0.1)
|
||||
|
||||
return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags)
|
||||
|
||||
|
||||
def score_hackernews_comment(comment: dict) -> CredibilityResult:
|
||||
"""Score a HN comment for credibility."""
|
||||
score = 1.0
|
||||
flags: list[str] = []
|
||||
|
||||
text = comment.get("comment_text", "")
|
||||
|
||||
text_flags = _check_text_patterns(text)
|
||||
if text_flags:
|
||||
score -= 0.1 * len(text_flags)
|
||||
flags.extend(text_flags)
|
||||
|
||||
# HN comments are generally higher quality
|
||||
score = min(1.0, score + 0.1)
|
||||
|
||||
return CredibilityResult(score=max(0.0, min(1.0, score)), flags=flags)
|
||||
|
||||
|
||||
# --- Batch-level coordination detection ---
|
||||
|
||||
|
||||
def detect_coordination(posts: list[dict], text_key: str = "text") -> list[str]:
|
||||
"""Detect coordinated inauthentic behavior across a batch of posts.
|
||||
|
||||
Looks for:
|
||||
- Duplicate or near-duplicate text (copy-paste campaigns)
|
||||
- Burst posting (many posts in a very short window)
|
||||
- Same talking points with minor variations
|
||||
|
||||
Returns a list of warning strings.
|
||||
"""
|
||||
warnings: list[str] = []
|
||||
texts = [p.get(text_key, "") for p in posts if p.get(text_key)]
|
||||
|
||||
if not texts:
|
||||
return warnings
|
||||
|
||||
# Exact duplicates
|
||||
seen: dict[str, int] = {}
|
||||
for t in texts:
|
||||
normalized = t.strip().lower()
|
||||
seen[normalized] = seen.get(normalized, 0) + 1
|
||||
|
||||
duplicates = {text: count for text, count in seen.items() if count > 1}
|
||||
if duplicates:
|
||||
total_dupes = sum(duplicates.values())
|
||||
warnings.append(
|
||||
f"COORDINATION WARNING: {len(duplicates)} duplicate texts found "
|
||||
f"({total_dupes} total copies). Possible copy-paste campaign."
|
||||
)
|
||||
|
||||
# Near-duplicates: check if many posts share a long common substring
|
||||
# (simplified: check if >30% of posts start with the same 50+ chars)
|
||||
if len(texts) >= 5:
|
||||
prefixes: dict[str, int] = {}
|
||||
for t in texts:
|
||||
prefix = t.strip().lower()[:80]
|
||||
if len(prefix) >= 50:
|
||||
prefixes[prefix] = prefixes.get(prefix, 0) + 1
|
||||
|
||||
for prefix, count in prefixes.items():
|
||||
if count >= len(texts) * 0.3:
|
||||
warnings.append(
|
||||
f"COORDINATION WARNING: {count}/{len(texts)} posts share "
|
||||
f"a common prefix ({prefix[:50]}...). Possible template campaign."
|
||||
)
|
||||
|
||||
# Burst detection: if timestamps are available
|
||||
timestamps = []
|
||||
for p in posts:
|
||||
created = p.get("created_at") or p.get("created_utc")
|
||||
if isinstance(created, str):
|
||||
try:
|
||||
timestamps.append(datetime.fromisoformat(created.replace("Z", "+00:00")))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
elif isinstance(created, (int, float)):
|
||||
timestamps.append(datetime.fromtimestamp(created, tz=timezone.utc))
|
||||
|
||||
if len(timestamps) >= 5:
|
||||
timestamps.sort()
|
||||
# Check if >50% of posts landed within a 5-minute window
|
||||
window_seconds = 300
|
||||
for i in range(len(timestamps) - 2):
|
||||
window_end = timestamps[i] + __import__("datetime").timedelta(seconds=window_seconds)
|
||||
in_window = sum(1 for t in timestamps if timestamps[i] <= t <= window_end)
|
||||
if in_window >= len(timestamps) * 0.5:
|
||||
warnings.append(
|
||||
f"COORDINATION WARNING: {in_window}/{len(timestamps)} posts "
|
||||
f"appeared within a 5-minute window. Possible coordinated posting."
|
||||
)
|
||||
break
|
||||
|
||||
return warnings
|
||||
|
||||
|
||||
def filter_and_annotate(
|
||||
posts: list[dict],
|
||||
scorer,
|
||||
min_score: float = 0.3,
|
||||
flag_threshold: float = 0.5,
|
||||
) -> tuple[list[dict], dict]:
|
||||
"""Score all posts, filter out low-credibility ones, and annotate the rest.
|
||||
|
||||
Args:
|
||||
posts: List of post dicts from any platform.
|
||||
scorer: A scoring function (e.g., score_reddit_post).
|
||||
min_score: Posts below this are excluded.
|
||||
flag_threshold: Posts between min_score and this are flagged.
|
||||
|
||||
Returns:
|
||||
Tuple of (filtered_posts, stats_dict).
|
||||
Each post in filtered_posts gets a "_credibility" key added.
|
||||
"""
|
||||
filtered = []
|
||||
stats = {
|
||||
"total": len(posts),
|
||||
"excluded": 0,
|
||||
"flagged": 0,
|
||||
"authentic": 0,
|
||||
"excluded_reasons": [],
|
||||
}
|
||||
|
||||
for post in posts:
|
||||
result = scorer(post)
|
||||
result.is_excluded = result.score < min_score
|
||||
result.is_flagged = min_score <= result.score < flag_threshold
|
||||
|
||||
if result.is_excluded:
|
||||
stats["excluded"] += 1
|
||||
stats["excluded_reasons"].append(
|
||||
{"score": round(result.score, 2), "flags": result.flags}
|
||||
)
|
||||
continue
|
||||
|
||||
post["_credibility"] = {
|
||||
"score": round(result.score, 2),
|
||||
"label": result.label,
|
||||
"flags": result.flags,
|
||||
"is_flagged": result.is_flagged,
|
||||
}
|
||||
|
||||
if result.is_flagged:
|
||||
stats["flagged"] += 1
|
||||
else:
|
||||
stats["authentic"] += 1
|
||||
|
||||
filtered.append(post)
|
||||
|
||||
return filtered, stats
|
||||
66
agentstuff/sentiment_agent/main.py
Normal file
66
agentstuff/sentiment_agent/main.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""CLI entry point for the sentiment analysis agent."""
|
||||
|
||||
import argparse
|
||||
import anyio
|
||||
|
||||
from sentiment_agent.agent import run_sentiment_analysis
|
||||
from sentiment_agent.config import SafetyConfig
|
||||
|
||||
|
||||
async def async_main(args: argparse.Namespace) -> None:
|
||||
config = SafetyConfig(
|
||||
max_turns=args.max_turns,
|
||||
max_budget_usd=args.max_budget,
|
||||
max_total_api_calls=args.max_api_calls,
|
||||
min_credibility_score=args.min_credibility,
|
||||
flag_bot_threshold=args.flag_threshold,
|
||||
)
|
||||
|
||||
result = await run_sentiment_analysis(
|
||||
topic=args.topic,
|
||||
sources=args.sources,
|
||||
config=config,
|
||||
)
|
||||
print("\n" + "=" * 60)
|
||||
print("SENTIMENT ANALYSIS REPORT")
|
||||
print("=" * 60)
|
||||
print(result)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run sentiment analysis on a topic with bot/disinfo detection",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("topic", help="The topic to analyze sentiment for")
|
||||
parser.add_argument(
|
||||
"--sources", nargs="*", help="Specific URLs or sources to also analyze"
|
||||
)
|
||||
|
||||
safety = parser.add_argument_group("safety limits")
|
||||
safety.add_argument(
|
||||
"--max-turns", type=int, default=20, help="Max agent turns"
|
||||
)
|
||||
safety.add_argument(
|
||||
"--max-budget", type=float, default=0.50, help="Max Claude API spend (USD)"
|
||||
)
|
||||
safety.add_argument(
|
||||
"--max-api-calls", type=int, default=50, help="Max total API calls across all platforms"
|
||||
)
|
||||
|
||||
credibility = parser.add_argument_group("credibility filtering")
|
||||
credibility.add_argument(
|
||||
"--min-credibility", type=float, default=0.3,
|
||||
help="Posts below this score are excluded (0.0-1.0)",
|
||||
)
|
||||
credibility.add_argument(
|
||||
"--flag-threshold", type=float, default=0.5,
|
||||
help="Posts between min and this are flagged but included (0.0-1.0)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
anyio.run(async_main, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
169
agentstuff/sentiment_agent/ratelimit.py
Normal file
169
agentstuff/sentiment_agent/ratelimit.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Rate limiter and API call budget tracker.
|
||||
|
||||
Enforces per-platform rate limits and a global call budget so the agent
|
||||
can't hammer APIs or run up unbounded costs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sentiment_agent.config import RateLimitConfig
|
||||
|
||||
|
||||
class BudgetExhaustedError(Exception):
|
||||
"""Raised when the global API call budget is spent."""
|
||||
|
||||
|
||||
class RateLimitExceededError(Exception):
|
||||
"""Raised when a platform's rate limit is hit and cooldown hasn't elapsed."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PlatformState:
|
||||
"""Tracks call timestamps and active request count for one platform."""
|
||||
|
||||
config: RateLimitConfig
|
||||
call_timestamps: list[float] = field(default_factory=list)
|
||||
active_requests: int = 0
|
||||
last_429_at: float = 0.0
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Manages rate limiting across all platforms + a global call budget.
|
||||
|
||||
Usage:
|
||||
limiter = RateLimiter(max_total_calls=50)
|
||||
limiter.register_platform("reddit", RateLimitConfig(...))
|
||||
|
||||
async with limiter.acquire("reddit"):
|
||||
await do_reddit_call()
|
||||
"""
|
||||
|
||||
def __init__(self, max_total_calls: int = 50):
|
||||
self._max_total = max_total_calls
|
||||
self._total_calls = 0
|
||||
self._platforms: dict[str, _PlatformState] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def total_calls(self) -> int:
|
||||
return self._total_calls
|
||||
|
||||
@property
|
||||
def remaining_calls(self) -> int:
|
||||
return max(0, self._max_total - self._total_calls)
|
||||
|
||||
def register_platform(self, name: str, config: RateLimitConfig) -> None:
|
||||
self._platforms[name] = _PlatformState(config=config)
|
||||
|
||||
def acquire(self, platform: str) -> _AcquireContext:
|
||||
"""Context manager that enforces rate limits before allowing a call."""
|
||||
return _AcquireContext(self, platform)
|
||||
|
||||
async def _acquire(self, platform: str) -> None:
|
||||
async with self._lock:
|
||||
if self._total_calls >= self._max_total:
|
||||
raise BudgetExhaustedError(
|
||||
f"Global API call budget exhausted ({self._max_total} calls). "
|
||||
"Increase max_total_api_calls in SafetyConfig to allow more."
|
||||
)
|
||||
|
||||
state = self._platforms.get(platform)
|
||||
if not state:
|
||||
raise ValueError(f"Platform '{platform}' not registered with rate limiter")
|
||||
|
||||
now = time.monotonic()
|
||||
|
||||
# Check 429 cooldown
|
||||
if state.last_429_at:
|
||||
elapsed = now - state.last_429_at
|
||||
if elapsed < state.config.cooldown_after_429:
|
||||
remaining = state.config.cooldown_after_429 - elapsed
|
||||
raise RateLimitExceededError(
|
||||
f"Platform '{platform}' is in cooldown after 429. "
|
||||
f"Try again in {remaining:.0f}s."
|
||||
)
|
||||
state.last_429_at = 0.0
|
||||
|
||||
# Check burst limit
|
||||
if state.active_requests >= state.config.burst_size:
|
||||
raise RateLimitExceededError(
|
||||
f"Platform '{platform}' burst limit reached "
|
||||
f"({state.config.burst_size} concurrent). Wait for a request to finish."
|
||||
)
|
||||
|
||||
# Check RPM: discard timestamps older than 60s, then check count
|
||||
cutoff = now - 60.0
|
||||
state.call_timestamps = [t for t in state.call_timestamps if t > cutoff]
|
||||
|
||||
if len(state.call_timestamps) >= state.config.requests_per_minute:
|
||||
oldest = state.call_timestamps[0]
|
||||
wait_time = 60.0 - (now - oldest)
|
||||
raise RateLimitExceededError(
|
||||
f"Platform '{platform}' rate limit: {state.config.requests_per_minute}/min. "
|
||||
f"Try again in {wait_time:.0f}s."
|
||||
)
|
||||
|
||||
# All clear — record the call
|
||||
state.call_timestamps.append(now)
|
||||
state.active_requests += 1
|
||||
self._total_calls += 1
|
||||
|
||||
async def _release(self, platform: str) -> None:
|
||||
async with self._lock:
|
||||
state = self._platforms.get(platform)
|
||||
if state:
|
||||
state.active_requests = max(0, state.active_requests - 1)
|
||||
|
||||
def record_429(self, platform: str) -> None:
|
||||
"""Call this when an API returns 429 to trigger cooldown."""
|
||||
state = self._platforms.get(platform)
|
||||
if state:
|
||||
state.last_429_at = time.monotonic()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Return current usage stats for logging/reporting."""
|
||||
stats: dict = {
|
||||
"total_calls": self._total_calls,
|
||||
"remaining_calls": self.remaining_calls,
|
||||
"platforms": {},
|
||||
}
|
||||
for name, state in self._platforms.items():
|
||||
now = time.monotonic()
|
||||
cutoff = now - 60.0
|
||||
recent = [t for t in state.call_timestamps if t > cutoff]
|
||||
stats["platforms"][name] = {
|
||||
"calls_last_60s": len(recent),
|
||||
"active_requests": state.active_requests,
|
||||
"rpm_limit": state.config.requests_per_minute,
|
||||
"in_cooldown": bool(
|
||||
state.last_429_at
|
||||
and (now - state.last_429_at) < state.config.cooldown_after_429
|
||||
),
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
class _AcquireContext:
|
||||
"""Async context manager for rate-limited API calls."""
|
||||
|
||||
def __init__(self, limiter: RateLimiter, platform: str):
|
||||
self._limiter = limiter
|
||||
self._platform = platform
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self._limiter._acquire(self._platform)
|
||||
|
||||
async def __aexit__(self, *exc_info) -> None:
|
||||
# Check if the call got a 429
|
||||
if exc_info[0] is not None:
|
||||
import httpx
|
||||
|
||||
exc = exc_info[1]
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429:
|
||||
self._limiter.record_429(self._platform)
|
||||
|
||||
await self._limiter._release(self._platform)
|
||||
352
agentstuff/sentiment_agent/tools.py
Normal file
352
agentstuff/sentiment_agent/tools.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Custom MCP tools for social media and forum data gathering.
|
||||
|
||||
Each tool wraps an API client, enforces rate limits, runs credibility
|
||||
scoring, and returns MCP-formatted results with bot/disinfo annotations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from claude_agent_sdk import tool, create_sdk_mcp_server
|
||||
|
||||
from sentiment_agent.clients import bluesky, reddit, hackernews
|
||||
from sentiment_agent.config import SafetyConfig
|
||||
from sentiment_agent.credibility import (
|
||||
detect_coordination,
|
||||
filter_and_annotate,
|
||||
score_bluesky_post,
|
||||
score_hackernews_comment,
|
||||
score_hackernews_post,
|
||||
score_reddit_comment,
|
||||
score_reddit_post,
|
||||
)
|
||||
from sentiment_agent.ratelimit import BudgetExhaustedError, RateLimiter
|
||||
|
||||
# Module-level state — initialized by create_social_tools_server()
|
||||
_limiter: RateLimiter | None = None
|
||||
_config: SafetyConfig | None = None
|
||||
|
||||
|
||||
def _get_limiter() -> RateLimiter:
|
||||
if _limiter is None:
|
||||
raise RuntimeError("Tools not initialized — call create_social_tools_server() first")
|
||||
return _limiter
|
||||
|
||||
|
||||
def _get_config() -> SafetyConfig:
|
||||
if _config is None:
|
||||
return SafetyConfig()
|
||||
return _config
|
||||
|
||||
|
||||
def _text_result(text: str) -> dict:
|
||||
return {"content": [{"type": "text", "text": text}]}
|
||||
|
||||
|
||||
def _error_result(error: str) -> dict:
|
||||
return {"content": [{"type": "text", "text": f"Error: {error}"}], "isError": True}
|
||||
|
||||
|
||||
def _clamp_limit(requested: int) -> int:
|
||||
"""Enforce max results per call."""
|
||||
return min(requested, _get_config().max_results_per_call)
|
||||
|
||||
|
||||
def _format_with_stats(
|
||||
posts: list[dict],
|
||||
stats: dict,
|
||||
coordination_warnings: list[str],
|
||||
platform: str,
|
||||
) -> str:
|
||||
"""Format results with credibility stats prepended."""
|
||||
header_parts = [
|
||||
f"Platform: {platform}",
|
||||
f"Results: {stats['authentic']} authentic, {stats['flagged']} flagged, "
|
||||
f"{stats['excluded']} excluded (of {stats['total']} total)",
|
||||
]
|
||||
if coordination_warnings:
|
||||
header_parts.append("--- COORDINATION ALERTS ---")
|
||||
header_parts.extend(coordination_warnings)
|
||||
header_parts.append("---")
|
||||
|
||||
limiter = _get_limiter()
|
||||
header_parts.append(f"API budget remaining: {limiter.remaining_calls} calls")
|
||||
|
||||
header = "\n".join(header_parts)
|
||||
body = json.dumps(posts, indent=2, default=str)
|
||||
return f"{header}\n\n{body}"
|
||||
|
||||
|
||||
# --- Bluesky tools ---
|
||||
|
||||
|
||||
@tool(
|
||||
"search_bluesky",
|
||||
"Search Bluesky for posts about a topic. Returns posts with text, author, "
|
||||
"engagement metrics, credibility scores, and bot/disinfo flags. "
|
||||
"Requires BLUESKY_HANDLE and BLUESKY_APP_PASSWORD env vars.",
|
||||
{"query": str, "limit": int, "sort": str},
|
||||
)
|
||||
async def search_bluesky(args: dict) -> dict:
|
||||
try:
|
||||
limiter = _get_limiter()
|
||||
config = _get_config()
|
||||
|
||||
async with limiter.acquire("bluesky"):
|
||||
posts = await bluesky.search_posts(
|
||||
query=args["query"],
|
||||
limit=_clamp_limit(args.get("limit", 25)),
|
||||
sort=args.get("sort", "top"),
|
||||
)
|
||||
|
||||
if not posts:
|
||||
return _text_result(f"No Bluesky posts found for: {args['query']}")
|
||||
|
||||
coordination = detect_coordination(posts, text_key="text")
|
||||
filtered, stats = filter_and_annotate(
|
||||
posts, score_bluesky_post,
|
||||
min_score=config.min_credibility_score,
|
||||
flag_threshold=config.flag_bot_threshold,
|
||||
)
|
||||
return _text_result(_format_with_stats(filtered, stats, coordination, "Bluesky"))
|
||||
except BudgetExhaustedError as e:
|
||||
return _error_result(str(e))
|
||||
except Exception as e:
|
||||
return _error_result(f"Bluesky search failed: {e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
@tool(
|
||||
"get_bluesky_thread",
|
||||
"Fetch a Bluesky thread/post and its replies with credibility scoring. "
|
||||
"Accepts an at:// URI or https://bsky.app/... URL.",
|
||||
{"uri": str, "depth": int},
|
||||
)
|
||||
async def get_bluesky_thread(args: dict) -> dict:
|
||||
try:
|
||||
limiter = _get_limiter()
|
||||
config = _get_config()
|
||||
|
||||
async with limiter.acquire("bluesky"):
|
||||
thread = await bluesky.get_thread(
|
||||
uri=args["uri"],
|
||||
depth=args.get("depth", 6),
|
||||
)
|
||||
|
||||
# Score replies
|
||||
if thread.get("replies"):
|
||||
coordination = detect_coordination(thread["replies"], text_key="text")
|
||||
filtered_replies, stats = filter_and_annotate(
|
||||
thread["replies"], score_bluesky_post,
|
||||
min_score=config.min_credibility_score,
|
||||
flag_threshold=config.flag_bot_threshold,
|
||||
)
|
||||
thread["replies"] = filtered_replies
|
||||
thread["_reply_credibility_stats"] = stats
|
||||
thread["_coordination_warnings"] = coordination
|
||||
|
||||
# Score root post
|
||||
if thread.get("post"):
|
||||
result = score_bluesky_post(thread["post"])
|
||||
thread["post"]["_credibility"] = {
|
||||
"score": round(result.score, 2),
|
||||
"label": result.label,
|
||||
"flags": result.flags,
|
||||
}
|
||||
|
||||
return _text_result(json.dumps(thread, indent=2, default=str))
|
||||
except BudgetExhaustedError as e:
|
||||
return _error_result(str(e))
|
||||
except Exception as e:
|
||||
return _error_result(f"Bluesky thread fetch failed: {e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
# --- Reddit tools ---
|
||||
|
||||
|
||||
@tool(
|
||||
"search_reddit",
|
||||
"Search Reddit for posts about a topic. Returns posts with credibility scores "
|
||||
"and bot/disinfo flags. Posts below the credibility threshold are auto-excluded. "
|
||||
"Use subreddit='all' for site-wide or specify a subreddit name.",
|
||||
{"query": str, "subreddit": str, "sort": str, "time_filter": str, "limit": int},
|
||||
)
|
||||
async def search_reddit_tool(args: dict) -> dict:
|
||||
try:
|
||||
limiter = _get_limiter()
|
||||
config = _get_config()
|
||||
|
||||
async with limiter.acquire("reddit"):
|
||||
posts = await reddit.search_posts(
|
||||
query=args["query"],
|
||||
subreddit=args.get("subreddit", "all"),
|
||||
sort=args.get("sort", "relevance"),
|
||||
time_filter=args.get("time_filter", "month"),
|
||||
limit=_clamp_limit(args.get("limit", 25)),
|
||||
)
|
||||
|
||||
if not posts:
|
||||
return _text_result(f"No Reddit posts found for: {args['query']}")
|
||||
|
||||
coordination = detect_coordination(posts, text_key="title")
|
||||
filtered, stats = filter_and_annotate(
|
||||
posts, score_reddit_post,
|
||||
min_score=config.min_credibility_score,
|
||||
flag_threshold=config.flag_bot_threshold,
|
||||
)
|
||||
return _text_result(_format_with_stats(filtered, stats, coordination, "Reddit"))
|
||||
except BudgetExhaustedError as e:
|
||||
return _error_result(str(e))
|
||||
except Exception as e:
|
||||
return _error_result(f"Reddit search failed: {e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
@tool(
|
||||
"get_reddit_comments",
|
||||
"Fetch comments for a Reddit post with credibility scoring. "
|
||||
"Pass the permalink path or full URL.",
|
||||
{"permalink": str, "sort": str, "limit": int},
|
||||
)
|
||||
async def get_reddit_comments(args: dict) -> dict:
|
||||
try:
|
||||
limiter = _get_limiter()
|
||||
config = _get_config()
|
||||
|
||||
async with limiter.acquire("reddit"):
|
||||
comments = await reddit.get_post_comments(
|
||||
permalink=args["permalink"],
|
||||
sort=args.get("sort", "top"),
|
||||
limit=_clamp_limit(args.get("limit", 25)),
|
||||
)
|
||||
|
||||
if not comments:
|
||||
return _text_result("No comments found for this post.")
|
||||
|
||||
coordination = detect_coordination(comments, text_key="body")
|
||||
filtered, stats = filter_and_annotate(
|
||||
comments, score_reddit_comment,
|
||||
min_score=config.min_credibility_score,
|
||||
flag_threshold=config.flag_bot_threshold,
|
||||
)
|
||||
return _text_result(_format_with_stats(filtered, stats, coordination, "Reddit Comments"))
|
||||
except BudgetExhaustedError as e:
|
||||
return _error_result(str(e))
|
||||
except Exception as e:
|
||||
return _error_result(f"Reddit comments fetch failed: {e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
# --- Hacker News tools ---
|
||||
|
||||
|
||||
@tool(
|
||||
"search_hackernews",
|
||||
"Search Hacker News for stories with credibility scoring. "
|
||||
"No authentication required.",
|
||||
{"query": str, "limit": int},
|
||||
)
|
||||
async def search_hackernews_tool(args: dict) -> dict:
|
||||
try:
|
||||
limiter = _get_limiter()
|
||||
config = _get_config()
|
||||
|
||||
async with limiter.acquire("hackernews"):
|
||||
stories = await hackernews.search_stories(
|
||||
query=args["query"],
|
||||
limit=_clamp_limit(args.get("limit", 25)),
|
||||
)
|
||||
|
||||
if not stories:
|
||||
return _text_result(f"No HN stories found for: {args['query']}")
|
||||
|
||||
coordination = detect_coordination(stories, text_key="title")
|
||||
filtered, stats = filter_and_annotate(
|
||||
stories, score_hackernews_post,
|
||||
min_score=config.min_credibility_score,
|
||||
flag_threshold=config.flag_bot_threshold,
|
||||
)
|
||||
return _text_result(_format_with_stats(filtered, stats, coordination, "Hacker News"))
|
||||
except BudgetExhaustedError as e:
|
||||
return _error_result(str(e))
|
||||
except Exception as e:
|
||||
return _error_result(f"HN search failed: {e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
@tool(
|
||||
"search_hackernews_comments",
|
||||
"Search Hacker News comments for opinions and discussions with credibility scoring.",
|
||||
{"query": str, "limit": int},
|
||||
)
|
||||
async def search_hackernews_comments(args: dict) -> dict:
|
||||
try:
|
||||
limiter = _get_limiter()
|
||||
config = _get_config()
|
||||
|
||||
async with limiter.acquire("hackernews"):
|
||||
comments = await hackernews.search_comments(
|
||||
query=args["query"],
|
||||
limit=_clamp_limit(args.get("limit", 25)),
|
||||
)
|
||||
|
||||
if not comments:
|
||||
return _text_result(f"No HN comments found for: {args['query']}")
|
||||
|
||||
coordination = detect_coordination(comments, text_key="comment_text")
|
||||
filtered, stats = filter_and_annotate(
|
||||
comments, score_hackernews_comment,
|
||||
min_score=config.min_credibility_score,
|
||||
flag_threshold=config.flag_bot_threshold,
|
||||
)
|
||||
return _text_result(
|
||||
_format_with_stats(filtered, stats, coordination, "HN Comments")
|
||||
)
|
||||
except BudgetExhaustedError as e:
|
||||
return _error_result(str(e))
|
||||
except Exception as e:
|
||||
return _error_result(f"HN comment search failed: {e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
# --- Budget status tool ---
|
||||
|
||||
|
||||
@tool(
|
||||
"get_api_budget_status",
|
||||
"Check remaining API call budget, rate limit status, and per-platform stats. "
|
||||
"Use this before making more API calls to avoid hitting limits.",
|
||||
{},
|
||||
)
|
||||
async def get_api_budget_status(args: dict) -> dict:
|
||||
limiter = _get_limiter()
|
||||
stats = limiter.get_stats()
|
||||
return _text_result(json.dumps(stats, indent=2, default=str))
|
||||
|
||||
|
||||
# --- Server factory ---
|
||||
|
||||
|
||||
def create_social_tools_server(config: SafetyConfig | None = None):
|
||||
"""Create an MCP server with all social media/forum tools.
|
||||
|
||||
Initializes rate limiting and credibility thresholds from config.
|
||||
"""
|
||||
global _limiter, _config
|
||||
|
||||
_config = config or SafetyConfig.from_env()
|
||||
|
||||
_limiter = RateLimiter(max_total_calls=_config.max_total_api_calls)
|
||||
_limiter.register_platform("bluesky", _config.bluesky_rate)
|
||||
_limiter.register_platform("reddit", _config.reddit_rate)
|
||||
_limiter.register_platform("hackernews", _config.hackernews_rate)
|
||||
|
||||
return create_sdk_mcp_server(
|
||||
"social-tools",
|
||||
tools=[
|
||||
search_bluesky,
|
||||
get_bluesky_thread,
|
||||
search_reddit_tool,
|
||||
get_reddit_comments,
|
||||
search_hackernews_tool,
|
||||
search_hackernews_comments,
|
||||
get_api_budget_status,
|
||||
],
|
||||
)
|
||||
Reference in New Issue
Block a user