diff --git a/backend/main.py b/backend/main.py index 93d2950..f7a6277 100644 --- a/backend/main.py +++ b/backend/main.py @@ -6237,6 +6237,39 @@ class Session: self.relationship_context: dict[str, str] = {} # caller_key → relationship prompt injection self.intern_monitoring: bool = True # Devon monitors conversations by default self.show_theme: str = "" # Current show theme (e.g. "St. Patrick's Day") + # Caller model routing + self.caller_model_strategy: str = "single" # "single" | "cycle" | "style_matched" + self.caller_model_pool: list[str] = [ + "x-ai/grok-4", + "anthropic/claude-sonnet-4-5", + "mistralai/mistral-medium-3", + "qwen/qwen3-235b-a22b", + "deepseek/deepseek-chat-v3-0324", + "google/gemini-2.5-pro", + "meta-llama/llama-4-maverick", + ] + self.caller_model_map: dict[str, str] = { + "high_energy": "x-ai/grok-4", + "confrontational": "x-ai/grok-4", + "angry_venting": "x-ai/grok-4", + "bragger": "x-ai/grok-4", + "comedian": "x-ai/grok-4", + "quiet_nervous": "anthropic/claude-sonnet-4-5", + "sweet_earnest": "anthropic/claude-sonnet-4-5", + "emotional": "anthropic/claude-sonnet-4-5", + "deadpan": "mistralai/mistral-medium-3", + "mysterious": "mistralai/mistral-medium-3", + "world_weary": "mistralai/mistral-medium-3", + "storyteller": "qwen/qwen3-235b-a22b", + "rambling": "qwen/qwen3-235b-a22b", + "oversharer": "deepseek/deepseek-chat-v3-0324", + "conspiracy": "deepseek/deepseek-chat-v3-0324", + "know_it_all": "google/gemini-2.5-pro", + "first_time": "meta-llama/llama-4-maverick", + } + self.caller_model_fallback: str = "anthropic/claude-sonnet-4-5" + self.caller_models: dict[str, str] = {} # caller_key → assigned model + self._caller_model_cycle_idx: int = 0 def start_call(self, caller_key: str): self.current_caller_key = caller_key @@ -6253,6 +6286,35 @@ class Session: def add_message(self, role: str, content: str): self.conversation.append({"role": role, "content": content, "timestamp": time.time()}) + def get_caller_model(self, caller_key: str) -> str | None: + """Get the assigned model for a caller, or assign one based on strategy. + Returns None to use default category routing.""" + if self.caller_model_strategy == "single": + return None # use default category_models["caller_dialog"] + + # Already assigned — keep consistent for the whole call + if caller_key in self.caller_models: + return self.caller_models[caller_key] + + model = None + if self.caller_model_strategy == "cycle": + if self.caller_model_pool: + model = self.caller_model_pool[self._caller_model_cycle_idx % len(self.caller_model_pool)] + self._caller_model_cycle_idx += 1 + elif self.caller_model_strategy == "style_matched": + raw_style = self.caller_styles.get(caller_key, "") + style_key = _normalize_style_key(raw_style) if raw_style else "" + model = self.caller_model_map.get(style_key) + if not model and self.caller_model_pool: + model = self.caller_model_pool[0] + + if model: + self.caller_models[caller_key] = model + caller_name = CALLER_BASES.get(caller_key, {}).get("name", caller_key) + print(f"[CallerModel] Assigned {model} to {caller_name} (strategy={self.caller_model_strategy})") + + return model + def get_caller_background(self, caller_key: str) -> str: """Get or generate background for a caller in this session. Returns the natural_description string for prompt injection.""" @@ -6607,6 +6669,12 @@ def _save_checkpoint(): "caller_queue": session.caller_queue, "relationship_context": session.relationship_context, "intern_monitoring": session.intern_monitoring, + "caller_model_strategy": session.caller_model_strategy, + "caller_model_pool": session.caller_model_pool, + "caller_model_map": session.caller_model_map, + "caller_model_fallback": session.caller_model_fallback, + "caller_models": session.caller_models, + "caller_model_cycle_idx": session._caller_model_cycle_idx, "costs": cost_tracker.get_live_summary(), "cost_records": { "llm": [asdict(r) for r in cost_tracker.llm_records], @@ -6653,6 +6721,12 @@ def _load_checkpoint() -> bool: session.caller_queue = data.get("caller_queue", []) session.relationship_context = data.get("relationship_context", {}) session.intern_monitoring = data.get("intern_monitoring", True) + session.caller_model_strategy = data.get("caller_model_strategy", "single") + session.caller_model_pool = data.get("caller_model_pool", ["anthropic/claude-sonnet-4-5"]) + session.caller_model_map = data.get("caller_model_map", {}) + session.caller_model_fallback = data.get("caller_model_fallback", "anthropic/claude-sonnet-4-5") + session.caller_models = data.get("caller_models", {}) + session._caller_model_cycle_idx = data.get("caller_model_cycle_idx", 0) for key, snapshot in data.get("caller_bases", {}).items(): if key in CALLER_BASES: CALLER_BASES[key]["name"] = snapshot["name"] @@ -8563,6 +8637,7 @@ async def chat(request: ChatRequest): max_tokens=max_tokens, category="caller_dialog", caller_name=session.caller.get("name", "") if session.caller else "", + model_override=session.get_caller_model(session.current_caller_key) if session.current_caller_key else None, ) # Discard if call changed while we were generating @@ -8953,6 +9028,74 @@ async def set_show_theme(data: dict): return {"theme": session.show_theme} +# --- Caller Model Routing --- + +@app.get("/api/caller-models") +async def get_caller_models(): + """Get current caller model routing config and per-caller assignments.""" + assignments = {} + for key in CALLER_BASES: + name = CALLER_BASES[key].get("name", key) + model = session.caller_models.get(key) + assignments[key] = {"name": name, "model": model or "(default)"} + return { + "strategy": session.caller_model_strategy, + "pool": session.caller_model_pool, + "map": session.caller_model_map, + "fallback": session.caller_model_fallback, + "assignments": assignments, + } + + +@app.post("/api/caller-models") +async def set_caller_models(data: dict): + """Update caller model routing strategy, pool, map, or fallback.""" + if "strategy" in data: + strategy = data["strategy"] + if strategy not in ("single", "cycle", "style_matched"): + raise HTTPException(400, f"Invalid strategy: {strategy}") + session.caller_model_strategy = strategy + print(f"[CallerModel] Strategy set to: {strategy}") + if "pool" in data: + pool = data["pool"] + if not isinstance(pool, list) or not pool: + raise HTTPException(400, "pool must be a non-empty list of model IDs") + session.caller_model_pool = pool + print(f"[CallerModel] Pool set to: {pool}") + if "map" in data: + session.caller_model_map = data["map"] + print(f"[CallerModel] Style map set: {len(data['map'])} entries") + if "fallback" in data: + session.caller_model_fallback = data["fallback"] + print(f"[CallerModel] Fallback set to: {data['fallback']}") + # Clear existing assignments so new strategy takes effect + if "strategy" in data or "pool" in data or "map" in data: + session.caller_models.clear() + session._caller_model_cycle_idx = 0 + print(f"[CallerModel] Cleared caller assignments (new config)") + _save_checkpoint() + return await get_caller_models() + + +@app.post("/api/caller-models/{caller_key}") +async def set_caller_model_override(caller_key: str, data: dict): + """Override the model for a specific caller mid-show.""" + if caller_key not in CALLER_BASES: + raise HTTPException(404, f"Unknown caller key: {caller_key}") + model = data.get("model", "").strip() + if not model: + # Clear override + session.caller_models.pop(caller_key, None) + name = CALLER_BASES[caller_key].get("name", caller_key) + print(f"[CallerModel] Cleared override for {name}") + else: + session.caller_models[caller_key] = model + name = CALLER_BASES[caller_key].get("name", caller_key) + print(f"[CallerModel] Override {name} → {model}") + _save_checkpoint() + return {"caller_key": caller_key, "model": session.caller_models.get(caller_key, "(default)")} + + # --- Cost Tracking Endpoints --- @app.get("/api/costs") @@ -9442,6 +9585,7 @@ async def _trigger_ai_auto_respond(accumulated_text: str): max_tokens=max_tokens, category="caller_dialog", caller_name=session.caller.get("name", "") if session.caller else "", + model_override=session.get_caller_model(session.current_caller_key) if session.current_caller_key else None, ) # Discard if call changed during generation @@ -9543,6 +9687,7 @@ async def ai_respond(): max_tokens=max_tokens, category="caller_dialog", caller_name=session.caller.get("name", "") if session.caller else "", + model_override=session.get_caller_model(session.current_caller_key) if session.current_caller_key else None, ) if _session_epoch != epoch: diff --git a/backend/services/cost_tracker.py b/backend/services/cost_tracker.py index 3b9087b..c3d2c4c 100644 --- a/backend/services/cost_tracker.py +++ b/backend/services/cost_tracker.py @@ -45,6 +45,12 @@ OPENROUTER_PRICING = { "openai/gpt-4o-mini": {"prompt": 0.15, "completion": 0.60}, "openai/gpt-4o": {"prompt": 2.50, "completion": 10.00}, "meta-llama/llama-3.1-8b-instruct": {"prompt": 0.06, "completion": 0.06}, + "deepseek/deepseek-chat-v3-0324": {"prompt": 0.27, "completion": 1.10}, + "moonshotai/kimi-k2": {"prompt": 0.60, "completion": 2.00}, + "mistralai/mistral-medium-3": {"prompt": 0.40, "completion": 2.00}, + "meta-llama/llama-4-maverick": {"prompt": 0.20, "completion": 0.60}, + "qwen/qwen3-235b-a22b": {"prompt": 0.20, "completion": 0.60}, + "google/gemini-2.5-pro": {"prompt": 1.25, "completion": 10.00}, } # TTS pricing per character diff --git a/backend/services/llm.py b/backend/services/llm.py index c56baeb..3e5704f 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -23,6 +23,13 @@ OPENROUTER_MODELS = [ "google/gemini-2.5-flash", "openai/gpt-4o-mini", "openai/gpt-4o", + # New dialog models + "deepseek/deepseek-chat-v3-0324", + "moonshotai/kimi-k2", + "mistralai/mistral-medium-3", + "meta-llama/llama-4-maverick", + "qwen/qwen3-235b-a22b", + "google/gemini-2.5-pro", # Legacy "anthropic/claude-3-haiku", "google/gemini-flash-1.5", @@ -125,12 +132,13 @@ class LLMService: response_format: Optional[dict] = None, category: str = "unknown", caller_name: str = "", + model_override: Optional[str] = None, ) -> str: if system_prompt: messages = [{"role": "system", "content": system_prompt}] + messages if self.provider == "openrouter": - return await self._call_openrouter_with_fallback(messages, max_tokens=max_tokens, response_format=response_format, category=category, caller_name=caller_name) + return await self._call_openrouter_with_fallback(messages, max_tokens=max_tokens, response_format=response_format, category=category, caller_name=caller_name, model_override=model_override) else: return await self._call_ollama(messages, max_tokens=max_tokens) @@ -295,11 +303,11 @@ class LLMService: """Get the best model for a given category based on config routing.""" return settings.category_models.get(category, self.openrouter_model) - async def _call_openrouter_with_fallback(self, messages: list[dict], max_tokens: Optional[int] = None, response_format: Optional[dict] = None, category: str = "unknown", caller_name: str = "") -> str: + async def _call_openrouter_with_fallback(self, messages: list[dict], max_tokens: Optional[int] = None, response_format: Optional[dict] = None, category: str = "unknown", caller_name: str = "", model_override: Optional[str] = None) -> str: """Try category-specific model, then fallback models. Always returns a response.""" - # Use category-specific model if configured, otherwise primary - model = self._get_model_for_category(category) + # Use explicit override if provided, else category routing, else primary + model = model_override or self._get_model_for_category(category) result = await self._call_openrouter_once(messages, model, max_tokens=max_tokens, response_format=response_format, category=category, caller_name=caller_name) if result is not None: return result diff --git a/frontend/css/style.css b/frontend/css/style.css index c254f5b..832dceb 100644 --- a/frontend/css/style.css +++ b/frontend/css/style.css @@ -463,6 +463,84 @@ section h2 { line-height: 1.3; } +/* Caller model indicator */ +.info-badge.model { + background: rgba(100, 140, 220, 0.2); + color: #7ab0e8; + font-size: 0.7rem; + cursor: pointer; +} + +.caller-model-override { + font-size: 0.7rem; + padding: 2px 4px; + background: var(--bg); + color: var(--text); + border: 1px solid rgba(100, 140, 220, 0.3); + border-radius: 4px; + max-width: 140px; +} + +/* Caller button model badge */ +.model-tag { + font-size: 0.55rem; + color: #7ab0e8; + background: rgba(100, 140, 220, 0.15); + padding: 0 3px; + border-radius: 2px; + font-weight: 700; + letter-spacing: 0.3px; + flex-shrink: 0; +} + +/* Caller Models settings section */ +.caller-model-row { + margin-bottom: 8px; +} + +.caller-model-row label { + margin-bottom: 0; +} + +.cm-pool-input { + font-size: 0.8rem; +} + +.cm-style-grid { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 4px; + margin-bottom: 8px; + max-height: 200px; + overflow-y: auto; +} + +.cm-style-item { + display: flex; + align-items: center; + justify-content: space-between; + gap: 4px; + background: rgba(255, 255, 255, 0.05); + border-radius: 4px; + padding: 3px 6px; +} + +.cm-style-name { + font-size: 0.7rem; + color: var(--text-muted); + white-space: nowrap; +} + +.cm-style-select { + font-size: 0.7rem; + padding: 2px 3px; + background: var(--bg); + color: var(--text); + border: 1px solid rgba(232, 121, 29, 0.15); + border-radius: 4px; + max-width: 110px; +} + .caller-background-full { margin-top: 8px; font-size: 0.75rem; diff --git a/frontend/index.html b/frontend/index.html index 7c9e8ed..e102df4 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -75,6 +75,8 @@ + +
@@ -285,6 +287,36 @@ + +