"""Custom MCP tools for social media and forum data gathering.

Each tool wraps an API client, enforces rate limits, runs credibility
scoring, and returns MCP-formatted results with bot/disinfo annotations.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import traceback
|
|
|
|
from claude_agent_sdk import tool, create_sdk_mcp_server
|
|
|
|
from sentiment_agent.clients import bluesky, reddit, hackernews
|
|
from sentiment_agent.config import SafetyConfig
|
|
from sentiment_agent.credibility import (
|
|
detect_coordination,
|
|
filter_and_annotate,
|
|
score_bluesky_post,
|
|
score_hackernews_comment,
|
|
score_hackernews_post,
|
|
score_reddit_comment,
|
|
score_reddit_post,
|
|
)
|
|
from sentiment_agent.ratelimit import BudgetExhaustedError, RateLimiter
|
|
|
|
# Module-level state — initialized by create_social_tools_server()
# Shared rate limiter. Stays None until create_social_tools_server() assigns
# it; _get_limiter() raises RuntimeError while it is still None.
_limiter: RateLimiter | None = None
# Active safety settings. While this is None, _get_config() hands back a
# freshly constructed SafetyConfig() instead.
_config: SafetyConfig | None = None
|
|
|
|
|
|
def _get_limiter() -> RateLimiter:
    """Return the module-wide rate limiter.

    Raises:
        RuntimeError: If the module-level ``_limiter`` slot is still ``None``,
            i.e. ``create_social_tools_server()`` has not populated it yet.
    """
    active = _limiter
    if active is not None:
        return active
    raise RuntimeError("Tools not initialized — call create_social_tools_server() first")
|
|
|
|
|
|
def _get_config() -> SafetyConfig:
    """Return the active safety config, or a default-constructed fallback.

    When the module-level ``_config`` slot is ``None``, a new ``SafetyConfig()``
    is built for this call only; it is not stored back into the module.
    """
    return SafetyConfig() if _config is None else _config
|
|
|
|
|
|
def _text_result(text: str) -> dict:
    """Wrap *text* as a tool result holding a single text content item."""
    text_item = {"type": "text", "text": text}
    return {"content": [text_item]}
|
|
|
|
|
|
def _error_result(error: str) -> dict:
    """Build a tool result flagged as an error.

    The message is prefixed with ``"Error: "`` and ``isError`` is set to True.
    """
    message = f"Error: {error}"
    payload = {"content": [{"type": "text", "text": message}]}
    payload["isError"] = True
    return payload
|
|
|
|
|
|
def _clamp_limit(requested: int) -> int:
    """Clamp a caller-supplied result count to the allowed per-call range.

    Args:
        requested: Number of results the tool caller asked for.

    Returns:
        ``requested`` raised to at least 1 and then capped at the active
        config's ``max_results_per_call``. The cap is applied last, so the
        configured maximum is never exceeded.
    """
    ceiling = _get_config().max_results_per_call
    # Tool arguments come from the model, so a zero or negative limit is
    # possible; floor it at 1 rather than forwarding a nonsensical count to
    # the API clients.
    return min(max(requested, 1), ceiling)
|
|
|
|
|
|
def _format_with_stats(
    posts: list[dict],
    stats: dict,
    coordination_warnings: list[str],
    platform: str,
) -> str:
    """Render scored results as a text report: summary header, then JSON body.

    Args:
        posts: Items to serialize as the JSON body.
        stats: Counts read under the keys ``authentic``, ``flagged``,
            ``excluded`` and ``total``.
        coordination_warnings: Alert lines; when non-empty they are inserted
            between ``--- COORDINATION ALERTS ---`` and ``---`` markers.
        platform: Label shown on the first header line.

    Returns:
        The header lines joined by newlines, a blank line, then ``posts``
        dumped as indented JSON (non-serializable values fall back to ``str``).
    """
    summary = (
        f"Results: {stats['authentic']} authentic, {stats['flagged']} flagged, "
        f"{stats['excluded']} excluded (of {stats['total']} total)"
    )
    lines = [f"Platform: {platform}", summary]

    if coordination_warnings:
        lines += ["--- COORDINATION ALERTS ---", *coordination_warnings, "---"]

    # Report the remaining budget alongside the results.
    calls_left = _get_limiter().remaining_calls
    lines.append(f"API budget remaining: {calls_left} calls")

    header = "\n".join(lines)
    body = json.dumps(posts, indent=2, default=str)
    return header + "\n\n" + body
|
|
|
|
|
|
# --- Bluesky tools ---
|
|
|
|
|
|
@tool(
    "search_bluesky",
    "Search Bluesky for posts about a topic. Returns posts with text, author, "
    "engagement metrics, credibility scores, and bot/disinfo flags. "
    "Requires BLUESKY_HANDLE and BLUESKY_APP_PASSWORD env vars.",
    {"query": str, "limit": int, "sort": str},
)
async def search_bluesky(args: dict) -> dict:
    """Tool handler: search Bluesky posts and return a credibility-scored report.

    Reads ``query`` (required), ``limit`` (default 25, clamped) and ``sort``
    (default ``"top"``) from *args*. The client call runs inside the
    ``"bluesky"`` rate-limit slot. Failures are returned as error results
    rather than raised.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("bluesky"):
            found = await bluesky.search_posts(
                query=args["query"],
                limit=_clamp_limit(args.get("limit", 25)),
                sort=args.get("sort", "top"),
            )

        if found:
            alerts = detect_coordination(found, text_key="text")
            thresholds = {
                "min_score": settings.min_credibility_score,
                "flag_threshold": settings.flag_bot_threshold,
            }
            kept, summary = filter_and_annotate(found, score_bluesky_post, **thresholds)
            report = _format_with_stats(kept, summary, alerts, "Bluesky")
            return _text_result(report)

        return _text_result(f"No Bluesky posts found for: {args['query']}")
    except BudgetExhaustedError as exc:
        # Budget exhaustion carries its own message; pass it through as-is.
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"Bluesky search failed: {exc}\n{traceback.format_exc()}")
|
|
|
|
|
|
@tool(
    "get_bluesky_thread",
    "Fetch a Bluesky thread/post and its replies with credibility scoring. "
    "Accepts an at:// URI or https://bsky.app/... URL.",
    {"uri": str, "depth": int},
)
async def get_bluesky_thread(args: dict) -> dict:
    """Tool handler: fetch one Bluesky thread and annotate it with credibility data.

    Reads ``uri`` (required) and ``depth`` (default 6) from *args*. When the
    thread has replies they are replaced by the filtered/annotated list, and
    the stats and coordination warnings are stored on the thread under
    ``_reply_credibility_stats`` and ``_coordination_warnings``. The root post,
    when present, gains a ``_credibility`` entry. The thread is returned as
    indented JSON text; failures are returned as error results.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("bluesky"):
            thread = await bluesky.get_thread(
                uri=args["uri"],
                depth=args.get("depth", 6),
            )

        # Replies: screen for coordination, then filter and annotate.
        replies = thread.get("replies")
        if replies:
            alerts = detect_coordination(replies, text_key="text")
            thresholds = {
                "min_score": settings.min_credibility_score,
                "flag_threshold": settings.flag_bot_threshold,
            }
            kept, summary = filter_and_annotate(replies, score_bluesky_post, **thresholds)
            thread["replies"] = kept
            thread["_reply_credibility_stats"] = summary
            thread["_coordination_warnings"] = alerts

        # Root post: scored on its own and annotated in place (never filtered).
        root = thread.get("post")
        if root:
            verdict = score_bluesky_post(root)
            root["_credibility"] = {
                "score": round(verdict.score, 2),
                "label": verdict.label,
                "flags": verdict.flags,
            }

        serialized = json.dumps(thread, indent=2, default=str)
        return _text_result(serialized)
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"Bluesky thread fetch failed: {exc}\n{traceback.format_exc()}")
|
|
|
|
|
|
# --- Reddit tools ---
|
|
|
|
|
|
@tool(
    "search_reddit",
    "Search Reddit for posts about a topic. Returns posts with credibility scores "
    "and bot/disinfo flags. Posts below the credibility threshold are auto-excluded. "
    "Use subreddit='all' for site-wide or specify a subreddit name.",
    {"query": str, "subreddit": str, "sort": str, "time_filter": str, "limit": int},
)
async def search_reddit_tool(args: dict) -> dict:
    """Tool handler: search Reddit posts and return a credibility-scored report.

    Reads ``query`` (required), ``subreddit`` (default ``"all"``), ``sort``
    (default ``"relevance"``), ``time_filter`` (default ``"month"``) and
    ``limit`` (default 25, clamped) from *args*. Coordination detection runs
    on each post's ``title``. Failures are returned as error results.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("reddit"):
            found = await reddit.search_posts(
                query=args["query"],
                subreddit=args.get("subreddit", "all"),
                sort=args.get("sort", "relevance"),
                time_filter=args.get("time_filter", "month"),
                limit=_clamp_limit(args.get("limit", 25)),
            )

        if found:
            alerts = detect_coordination(found, text_key="title")
            thresholds = {
                "min_score": settings.min_credibility_score,
                "flag_threshold": settings.flag_bot_threshold,
            }
            kept, summary = filter_and_annotate(found, score_reddit_post, **thresholds)
            report = _format_with_stats(kept, summary, alerts, "Reddit")
            return _text_result(report)

        return _text_result(f"No Reddit posts found for: {args['query']}")
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"Reddit search failed: {exc}\n{traceback.format_exc()}")
|
|
|
|
|
|
@tool(
    "get_reddit_comments",
    "Fetch comments for a Reddit post with credibility scoring. "
    "Pass the permalink path or full URL.",
    {"permalink": str, "sort": str, "limit": int},
)
async def get_reddit_comments(args: dict) -> dict:
    """Tool handler: fetch a Reddit post's comments and return a scored report.

    Reads ``permalink`` (required), ``sort`` (default ``"top"``) and ``limit``
    (default 25, clamped) from *args*. Coordination detection runs on each
    comment's ``body``. Failures are returned as error results.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("reddit"):
            fetched = await reddit.get_post_comments(
                permalink=args["permalink"],
                sort=args.get("sort", "top"),
                limit=_clamp_limit(args.get("limit", 25)),
            )

        if fetched:
            alerts = detect_coordination(fetched, text_key="body")
            thresholds = {
                "min_score": settings.min_credibility_score,
                "flag_threshold": settings.flag_bot_threshold,
            }
            kept, summary = filter_and_annotate(fetched, score_reddit_comment, **thresholds)
            report = _format_with_stats(kept, summary, alerts, "Reddit Comments")
            return _text_result(report)

        return _text_result("No comments found for this post.")
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"Reddit comments fetch failed: {exc}\n{traceback.format_exc()}")
|
|
|
|
|
|
# --- Hacker News tools ---
|
|
|
|
|
|
@tool(
    "search_hackernews",
    "Search Hacker News for stories with credibility scoring. "
    "No authentication required.",
    {"query": str, "limit": int},
)
async def search_hackernews_tool(args: dict) -> dict:
    """Tool handler: search Hacker News stories and return a scored report.

    Reads ``query`` (required) and ``limit`` (default 25, clamped) from
    *args*. Coordination detection runs on each story's ``title``. Failures
    are returned as error results.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("hackernews"):
            found = await hackernews.search_stories(
                query=args["query"],
                limit=_clamp_limit(args.get("limit", 25)),
            )

        if found:
            alerts = detect_coordination(found, text_key="title")
            thresholds = {
                "min_score": settings.min_credibility_score,
                "flag_threshold": settings.flag_bot_threshold,
            }
            kept, summary = filter_and_annotate(found, score_hackernews_post, **thresholds)
            report = _format_with_stats(kept, summary, alerts, "Hacker News")
            return _text_result(report)

        return _text_result(f"No HN stories found for: {args['query']}")
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"HN search failed: {exc}\n{traceback.format_exc()}")
|
|
|
|
|
|
@tool(
    "search_hackernews_comments",
    "Search Hacker News comments for opinions and discussions with credibility scoring.",
    {"query": str, "limit": int},
)
async def search_hackernews_comments(args: dict) -> dict:
    """Tool handler: search Hacker News comments and return a scored report.

    Reads ``query`` (required) and ``limit`` (default 25, clamped) from
    *args*. Coordination detection runs on each comment's ``comment_text``.
    Failures are returned as error results.
    """
    try:
        rate_limiter = _get_limiter()
        settings = _get_config()

        async with rate_limiter.acquire("hackernews"):
            found = await hackernews.search_comments(
                query=args["query"],
                limit=_clamp_limit(args.get("limit", 25)),
            )

        if found:
            alerts = detect_coordination(found, text_key="comment_text")
            thresholds = {
                "min_score": settings.min_credibility_score,
                "flag_threshold": settings.flag_bot_threshold,
            }
            kept, summary = filter_and_annotate(found, score_hackernews_comment, **thresholds)
            report = _format_with_stats(kept, summary, alerts, "HN Comments")
            return _text_result(report)

        return _text_result(f"No HN comments found for: {args['query']}")
    except BudgetExhaustedError as exc:
        return _error_result(str(exc))
    except Exception as exc:
        return _error_result(f"HN comment search failed: {exc}\n{traceback.format_exc()}")
|
|
|
|
|
|
# --- Budget status tool ---
|
|
|
|
|
|
@tool(
    "get_api_budget_status",
    "Check remaining API call budget, rate limit status, and per-platform stats. "
    "Use this before making more API calls to avoid hitting limits.",
    {},
)
async def get_api_budget_status(args: dict) -> dict:
    """Tool handler: report the rate limiter's current statistics as JSON text.

    Args:
        args: Unused; the tool declares an empty input schema.

    Returns:
        A text result holding ``limiter.get_stats()`` as indented JSON, or an
        error result if the limiter is not initialized or the stats cannot be
        produced or serialized.
    """
    # Every other tool handler in this module turns failures into an isError
    # result; do the same here so an uninitialized limiter does not surface
    # as a raw exception.
    try:
        usage = _get_limiter().get_stats()
        return _text_result(json.dumps(usage, indent=2, default=str))
    except Exception as exc:
        return _error_result(f"Budget status check failed: {exc}\n{traceback.format_exc()}")
|
|
|
|
|
|
# --- Server factory ---
|
|
|
|
|
|
def create_social_tools_server(config: SafetyConfig | None = None):
    """Build the ``social-tools`` MCP server and set up shared module state.

    Stores the effective config (the given *config*, or
    ``SafetyConfig.from_env()`` when none is supplied) and a fresh
    ``RateLimiter`` in the module globals, registers the ``bluesky``,
    ``reddit`` and ``hackernews`` platforms with their configured rates, and
    returns the server exposing all tool handlers defined in this module.
    """
    global _limiter, _config

    if config:
        _config = config
    else:
        _config = SafetyConfig.from_env()
    settings = _config

    _limiter = RateLimiter(max_total_calls=settings.max_total_api_calls)
    limiter = _limiter
    limiter.register_platform("bluesky", settings.bluesky_rate)
    limiter.register_platform("reddit", settings.reddit_rate)
    limiter.register_platform("hackernews", settings.hackernews_rate)

    handlers = [
        search_bluesky,
        get_bluesky_thread,
        search_reddit_tool,
        get_reddit_comments,
        search_hackernews_tool,
        search_hackernews_comments,
        get_api_budget_status,
    ]
    return create_sdk_mcp_server("social-tools", tools=handlers)
|