diff --git a/backend/main.py b/backend/main.py index c86905f..b233750 100644 --- a/backend/main.py +++ b/backend/main.py @@ -2,6 +2,7 @@ import uuid import asyncio +from dataclasses import dataclass, field from pathlib import Path from fastapi import FastAPI, HTTPException, UploadFile, File from fastapi.staticfiles import StaticFiles @@ -294,12 +295,24 @@ OUTPUT: Spoken words only. No (actions), no *gestures*, no stage directions.""" # --- Session State --- +@dataclass +class CallRecord: + caller_type: str # "ai" or "real" + caller_name: str # "Tony" or "Caller #3" + summary: str # LLM-generated summary after hangup + transcript: list[dict] = field(default_factory=list) + + class Session: def __init__(self): self.id = str(uuid.uuid4())[:8] self.current_caller_key: str = None self.conversation: list[dict] = [] self.caller_backgrounds: dict[str, str] = {} # Generated backgrounds for this session + self.call_history: list[CallRecord] = [] + self.active_real_caller: dict | None = None + self.ai_respond_mode: str = "manual" # "manual" or "auto" + self.auto_followup: bool = False def start_call(self, caller_key: str): self.current_caller_key = caller_key @@ -321,15 +334,39 @@ class Session: print(f"[Session {self.id}] Generated background for {base['name']}: {self.caller_backgrounds[caller_key][:100]}...") return self.caller_backgrounds.get(caller_key, "") + def get_show_history(self) -> str: + """Get formatted show history for AI caller prompts""" + if not self.call_history: + return "" + lines = ["EARLIER IN THE SHOW:"] + for record in self.call_history: + caller_type_label = "(real caller)" if record.caller_type == "real" else "(AI)" + lines.append(f"- {record.caller_name} {caller_type_label}: {record.summary}") + lines.append("You can reference these if it feels natural. Don't force it.") + return "\n".join(lines) + def get_conversation_summary(self) -> str: """Get a brief summary of conversation so far for context""" if len(self.conversation) <= 2: return "" - # Just include the key exchanges, not the full history summary_parts = [] - for msg in self.conversation[-6:]: # Last 3 exchanges - role = "Host" if msg["role"] == "user" else self.caller["name"] - summary_parts.append(f'{role}: "{msg["content"][:100]}..."' if len(msg["content"]) > 100 else f'{role}: "{msg["content"]}"') + for msg in self.conversation[-6:]: + role = msg["role"] + if role == "user" or role == "host": + label = "Host" + elif role.startswith("real_caller:"): + label = role.split(":", 1)[1] + elif role.startswith("ai_caller:"): + label = role.split(":", 1)[1] + elif role == "assistant": + label = self.caller["name"] if self.caller else "Caller" + else: + label = role + content = msg["content"] + summary_parts.append( + f'{label}: "{content[:100]}..."' if len(content) > 100 + else f'{label}: "{content}"' + ) return "\n".join(summary_parts) @property @@ -349,6 +386,10 @@ class Session: self.caller_backgrounds = {} self.current_caller_key = None self.conversation = [] + self.call_history = [] + self.active_real_caller = None + self.ai_respond_mode = "manual" + self.auto_followup = False self.id = str(uuid.uuid4())[:8] print(f"[Session] Reset - new session ID: {self.id}") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..35b5b39 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,89 @@ +import sys +sys.path.insert(0, "/Users/lukemacneil/ai-podcast") + +from backend.main import Session, CallRecord + + +def test_call_record_creation(): + record = CallRecord( + caller_type="real", + caller_name="Dave", + summary="Called about his wife leaving", + transcript=[{"role": "host", "content": "What happened?"}], + ) + assert record.caller_type == "real" + assert record.caller_name == "Dave" + + +def test_session_call_history(): + s = Session() + assert s.call_history == [] + record = CallRecord( + caller_type="ai", caller_name="Tony", + summary="Talked about gambling", transcript=[], + ) + s.call_history.append(record) + assert len(s.call_history) == 1 + + +def test_session_active_real_caller(): + s = Session() + assert s.active_real_caller is None + s.active_real_caller = { + "call_sid": "CA123", "phone": "+15125550142", + "channel": 3, "name": "Caller #1", + } + assert s.active_real_caller["channel"] == 3 + + +def test_session_three_party_conversation(): + s = Session() + s.start_call("1") # AI caller Tony + s.add_message("host", "Hey Tony") + s.add_message("ai_caller:Tony", "What's up man") + s.add_message("real_caller:Dave", "Yeah I agree with Tony") + assert len(s.conversation) == 3 + assert s.conversation[2]["role"] == "real_caller:Dave" + + +def test_session_get_show_history_summary(): + s = Session() + s.call_history.append(CallRecord( + caller_type="real", caller_name="Dave", + summary="Called about his wife leaving after 12 years", + transcript=[], + )) + s.call_history.append(CallRecord( + caller_type="ai", caller_name="Jasmine", + summary="Talked about her boss hitting on her", + transcript=[], + )) + summary = s.get_show_history() + assert "Dave" in summary + assert "Jasmine" in summary + assert "EARLIER IN THE SHOW" in summary + + +def test_session_reset_clears_history(): + s = Session() + s.call_history.append(CallRecord( + caller_type="real", caller_name="Dave", + summary="test", transcript=[], + )) + s.active_real_caller = {"call_sid": "CA123"} + s.ai_respond_mode = "auto" + s.reset() + assert s.call_history == [] + assert s.active_real_caller is None + assert s.ai_respond_mode == "manual" + + +def test_session_conversation_summary_three_party(): + s = Session() + s.start_call("1") + s.add_message("host", "Tell me what happened") + s.add_message("real_caller:Dave", "She just left man") + s.add_message("ai_caller:Tony", "Same thing happened to me") + summary = s.get_conversation_summary() + assert "Dave" in summary + assert "Tony" in summary