Skip to content
Open

AI mode #2529

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions maigret/ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Maigret AI Analysis Module

Provides AI-powered analysis of search results using OpenAI-compatible APIs.
"""

import asyncio
import json
import os
import sys
import threading

import aiohttp


def load_ai_prompt() -> str:
    """Return the text of the bundled AI system prompt.

    The prompt is read, as UTF-8, from ``resources/ai_prompt.txt`` inside the
    directory that holds this module (resolved through any symlinks).
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    prompt_file = os.path.join(module_dir, "resources", "ai_prompt.txt")
    with open(prompt_file, mode="r", encoding="utf-8") as handle:
        prompt_text = handle.read()
    return prompt_text


def resolve_api_key(settings) -> str | None:
"""Resolve OpenAI API key from settings or environment variable.

Priority: settings.openai_api_key > OPENAI_API_KEY env var.
"""
key = getattr(settings, "openai_api_key", None)
if key:
return key
return os.environ.get("OPENAI_API_KEY")


class _Spinner:
    """Simple animated spinner for terminal output.

    The animation is drawn on stderr from a daemon thread. Call ``start()``
    before a slow operation and ``stop()`` afterwards. Both calls are
    idempotent: a repeated ``start()`` does not spawn a second thread, and a
    repeated (or premature) ``stop()`` writes nothing, so it cannot erase
    output printed after the spinner was already cleared.
    """

    # Braille glyphs cycled through to produce the animation.
    FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    def __init__(self, text=""):
        self.text = text  # label rendered to the right of the glyph
        self._stop = threading.Event()  # set to ask the worker to exit
        self._thread = None  # worker thread; None while not running

    def start(self):
        """Start the animation thread; does nothing if already running."""
        if self._thread is not None:
            return
        # Clear a stop request left over from a previous start/stop cycle.
        self._stop.clear()
        self._thread = threading.Thread(target=self._spin, daemon=True)
        self._thread.start()

    def _spin(self):
        """Worker loop: redraw the current frame until a stop is requested."""
        i = 0
        while not self._stop.is_set():
            frame = self.FRAMES[i % len(self.FRAMES)]
            # "\r" returns to column 0 so each frame overwrites the last one.
            sys.stderr.write(f"\r{frame} {self.text}")
            sys.stderr.flush()
            i += 1
            # Waiting on the event (rather than sleeping) lets stop() wake
            # the thread immediately.
            self._stop.wait(0.08)

    def stop(self):
        """Stop the animation and erase the spinner line.

        Does nothing when the spinner is not running, so calling it twice
        (or without ``start()``) emits no escape sequence.
        """
        if self._thread is None:
            return
        self._stop.set()
        self._thread.join()
        self._thread = None
        # Carriage return + ANSI "erase entire line" removes the last frame.
        sys.stderr.write("\r\033[2K")
        sys.stderr.flush()


async def print_streaming(text: str, delay: float = 0.04):
    """Write *text* to stdout one word at a time, imitating streamed LLM output.

    The text is split on single spaces; each piece is written and flushed,
    followed by an ``asyncio.sleep(delay)`` pause. A trailing newline is
    written once every piece has been emitted.
    """
    for position, token in enumerate(text.split(" ")):
        # Every piece except the first is preceded by the space it was split on.
        prefix = " " if position else ""
        sys.stdout.write(prefix + token)
        sys.stdout.flush()
        await asyncio.sleep(delay)
    sys.stdout.write("\n")
    sys.stdout.flush()


def _parse_sse_line(raw_line: bytes) -> tuple[str, bool]:
    """Extract the text delta carried by one server-sent-events line.

    Returns ``(content, done)``. ``content`` is the streamed text fragment,
    or ``""`` when the line carries none (blank line, non-``data`` field,
    malformed JSON, or a chunk without a text delta). ``done`` is True only
    for the ``[DONE]`` terminator.
    """
    decoded = raw_line.decode("utf-8").strip()
    if not decoded.startswith("data:"):
        return "", False

    # The space after "data:" is optional in the SSE format, so strip it
    # rather than requiring it.
    data_str = decoded[len("data:"):].strip()
    if data_str == "[DONE]":
        return "", True

    try:
        chunk = json.loads(data_str)
    except json.JSONDecodeError:
        return "", False
    if not isinstance(chunk, dict):
        return "", False

    # Some chunks arrive with an empty "choices" list or a null "delta";
    # treat those as "no text" instead of raising.
    choices = chunk.get("choices") or [{}]
    first_choice = choices[0] if isinstance(choices[0], dict) else {}
    delta = first_choice.get("delta") or {}
    if not isinstance(delta, dict):
        return "", False
    return delta.get("content") or "", False


async def get_ai_analysis(
    api_key: str,
    markdown_report: str,
    model: str = "gpt-4o",
    api_base_url: str = "https://api.openai.com/v1",
) -> str:
    """Send the markdown report to an OpenAI-compatible API and return the analysis.

    The request is made in streaming mode: a spinner is shown on stderr until
    the first text fragment arrives, then fragments are written to stdout as
    they are received.

    Args:
        api_key: Bearer token sent in the ``Authorization`` header.
        markdown_report: Report text sent as the user message.
        model: Model identifier placed in the request payload.
        api_base_url: Base URL of the API; ``/chat/completions`` is appended.

    Returns:
        The complete analysis text, i.e. every streamed fragment joined in
        arrival order (empty if the API produced no text).

    Raises:
        RuntimeError: On any non-200 HTTP status, with a descriptive message.
        Exceptions raised by the HTTP client itself are propagated unchanged.
    """
    system_prompt = load_ai_prompt()

    url = f"{api_base_url.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "stream": True,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": markdown_report},
        ],
    }

    spinner = _Spinner("Analysing the data with AI...")
    spinner.start()
    spinner_active = True
    full_response = []

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as resp:
                if resp.status == 401:
                    raise RuntimeError("Invalid OpenAI API key (HTTP 401)")
                if resp.status == 429:
                    raise RuntimeError("OpenAI API rate limit exceeded (HTTP 429)")
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(
                        f"OpenAI API error (HTTP {resp.status}): {body[:500]}"
                    )

                async for line in resp.content:
                    content, done = _parse_sse_line(line)
                    if done:
                        break
                    if not content:
                        continue

                    if spinner_active:
                        # Clear the spinner before the first visible output.
                        spinner.stop()
                        spinner_active = False
                        print()

                    # Keep every fragment so the caller gets the full text.
                    full_response.append(content)
                    sys.stdout.write(content)
                    sys.stdout.flush()
    finally:
        # Stop the spinner exactly once: stopping it again after text has
        # been printed would erase the current terminal line.
        if spinner_active:
            spinner.stop()

    print()
    return "".join(full_response)
94 changes: 80 additions & 14 deletions maigret/maigret.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,21 @@ def setup_arguments_parser(settings: Settings):
" (one report per username).",
)

report_group.add_argument(
"--ai",
action="store_true",
dest="ai",
default=False,
help="Generate an AI-powered analysis of the search results using OpenAI API. "
"Requires OPENAI_API_KEY env var or openai_api_key in settings.",
)
report_group.add_argument(
"--ai-model",
dest="ai_model",
default=settings.openai_model,
help="OpenAI model to use for AI analysis (default: gpt-4o).",
)

parser.add_argument(
"--reports-sorting",
default=settings.report_sorting,
Expand Down Expand Up @@ -596,6 +611,7 @@ async def main():
print_found_only=not args.print_not_found,
skip_check_errors=not args.print_check_errors,
color=not args.no_color,
silent=args.ai,
)

# Create object with all information about sites we are aware of.
Expand Down Expand Up @@ -711,17 +727,33 @@ async def main():
+ get_dict_ascii_tree(usernames, prepend="\t")
)

if args.ai:
from .ai import resolve_api_key

if not resolve_api_key(settings):
query_notify.warning(
'AI analysis requires an OpenAI API key. '
'Set OPENAI_API_KEY environment variable or add '
'openai_api_key to settings.json.'
)
sys.exit(1)

if not site_data:
query_notify.warning('No sites to check, exiting!')
sys.exit(2)

query_notify.warning(
f'Starting a search on top {len(site_data)} sites from the Maigret database...'
)
if not args.all_sites:
if args.ai:
query_notify.warning(
f'Starting AI-assisted search on top {len(site_data)} sites from the Maigret database...'
)
else:
query_notify.warning(
'You can run search by full list of sites with flag `-a`', '!'
f'Starting a search on top {len(site_data)} sites from the Maigret database...'
)
if not args.all_sites:
query_notify.warning(
'You can run search by full list of sites with flag `-a`', '!'
)

already_checked = set()
general_results = []
Expand Down Expand Up @@ -774,11 +806,12 @@ async def main():
check_domains=args.with_domains,
)

errs = errors.notify_about_errors(
results, query_notify, show_statistics=args.verbose
)
for e in errs:
query_notify.warning(*e)
if not args.ai:
errs = errors.notify_about_errors(
results, query_notify, show_statistics=args.verbose
)
for e in errs:
query_notify.warning(*e)

if args.reports_sorting == "data":
results = sort_report_by_data_points(results)
Expand Down Expand Up @@ -867,10 +900,43 @@ async def main():
save_graph_report(filename, general_results, db)
query_notify.warning(f'Graph report on all usernames saved in {filename}')

text_report = get_plaintext_report(report_context)
if text_report:
query_notify.info('Short text report:')
print(text_report)
if not args.ai:
text_report = get_plaintext_report(report_context)
if text_report:
query_notify.info('Short text report:')
print(text_report)

if args.ai:
from .ai import get_ai_analysis, resolve_api_key
from .report import generate_markdown_report

api_key = resolve_api_key(settings)

run_flags = []
if args.tags:
run_flags.append(f"--tags {args.tags}")
if args.site_list:
run_flags.append(f"--site {','.join(args.site_list)}")
if args.all_sites:
run_flags.append("--all-sites")
run_info = {
"sites_count": sum(len(d) for _, _, d in general_results),
"flags": " ".join(run_flags) if run_flags else None,
}

md_report = generate_markdown_report(report_context, run_info=run_info)

try:
await get_ai_analysis(
api_key=api_key,
markdown_report=md_report,
model=args.ai_model,
api_base_url=getattr(
settings, 'openai_api_base_url', 'https://api.openai.com/v1'
),
)
except Exception as e:
query_notify.warning(f'AI analysis failed: {e}')

# update database
db.save_to_file(db_file)
Expand Down
8 changes: 8 additions & 0 deletions maigret/notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
print_found_only=False,
skip_check_errors=False,
color=True,
silent=False,
):
"""Create Query Notify Print Object.

Expand All @@ -149,6 +150,7 @@ def __init__(
self.print_found_only = print_found_only
self.skip_check_errors = skip_check_errors
self.color = color
self.silent = silent

return

Expand Down Expand Up @@ -187,6 +189,9 @@ def start(self, message=None, id_type="username"):
Nothing.
"""

if self.silent:
return

title = f"Checking {id_type}"
if self.color:
print(
Expand Down Expand Up @@ -236,6 +241,9 @@ def update(self, result, is_similar=False):
Return Value:
Nothing.
"""
if self.silent:
return

notify = None
self.result = result

Expand Down
9 changes: 7 additions & 2 deletions maigret/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _md_format_value(value) -> str:
return s


def save_markdown_report(filename: str, context: dict, run_info: dict = None):
def generate_markdown_report(context: dict, run_info: dict = None) -> str:
username = context.get("username", "unknown")
generated_at = context.get("generated_at", "")
brief = context.get("brief", "")
Expand Down Expand Up @@ -391,8 +391,13 @@ def save_markdown_report(filename: str, context: dict, run_info: dict = None):
"CCPA, and similar).\n"
)

return "\n".join(lines)


def save_markdown_report(filename: str, context: dict, run_info: dict = None):
content = generate_markdown_report(context, run_info)
with open(filename, "w", encoding="utf-8") as f:
f.write("\n".join(lines))
f.write(content)


"""
Expand Down
Loading
Loading