diff --git a/router/router.py b/router/router.py index ba229fb..2b03199 100644 --- a/router/router.py +++ b/router/router.py @@ -160,8 +160,19 @@ def is_gpu_busy(model): max_c = GPU_MAX_CONCURRENT.get(model, 1) return active >= max_c -def select_best_gpu(candidates, reason): - """Pick the best GPU from candidates IN ORDER — first non-busy one wins.""" +def select_best_gpu(candidates, reason, agent=""): + """Pick best GPU, spreading different agents across GPUs when possible.""" + # Track which GPUs this agent is already using + agent_gpus = set() + if agent and r: + for m in GPU_URLS: + if r.get("agent_gpu:" + agent + ":" + m): + agent_gpus.add(m) + # First pass: prefer GPUs NOT used by this agent + for m in candidates: + if not is_gpu_busy(m) and m not in agent_gpus: + return {"model": m, "reason": reason} + # Second pass: any non-busy GPU (agent reuse is ok) for m in candidates: if not is_gpu_busy(m): return {"model": m, "reason": reason} @@ -177,7 +188,7 @@ def select_best_gpu(candidates, reason): return {"model": best, "reason": "load_balanced_" + reason} return None -def route(rd, tier): +def route(rd, tier, agent=""): msgs = rd.get("messages",[]); t = estimate_tokens(msgs) sys = any(m.get("role")=="system" for m in msgs) turns = len([m for m in msgs if m.get("role") in ("user","assistant")]) @@ -196,15 +207,15 @@ def route(rd, tier): if is_gpu_busy(target) and req in allowed: alts = [m for m in avail if m != target and m in allowed] if alts: - alt = select_best_gpu(alts, "explicit") + alt = select_best_gpu(alts, "explicit", agent) if alt: return alt return {"model": target, "reason": "explicit"} if hints: if hints.get("priority")=="speed" and "qwen3.5-9b-vlm" in avail: - return select_best_gpu(["qwen3.5-9b-vlm"], "hint_speed") or {"model":"qwen3.5-9b-vlm","reason":"hint_speed"} + return select_best_gpu(["qwen3.5-9b-vlm"], "hint_speed", agent) or {"model":"qwen3.5-9b-vlm","reason":"hint_speed"} if hints.get("priority")=="quality" and "qwen3.6-35B-A3B" in avail: - return select_best_gpu(["qwen3.6-35B-A3B"], "hint_quality") or {"model":"qwen3.6-35B-A3B","reason":"hint_quality"} + return select_best_gpu(["qwen3.6-35B-A3B"], "hint_quality", agent) or {"model":"qwen3.6-35B-A3B","reason":"hint_quality"} first_msg = msgs[0].get("content","") if msgs else "" words = len(first_msg.split()) if isinstance(first_msg, str) else 99 @@ -215,7 +226,7 @@ def route(rd, tier): return {"model":"qwen3.5-9b-vlm","reason":"lightweight"} # VLM busy — Dense is faster for short queries than MoE fallback = [m for m in ["qwen3.6-27B-code","qwen3.6-35B-A3B"] if m in avail] - result = select_best_gpu(fallback, "lightweight_fallback") + result = select_best_gpu(fallback, "lightweight_fallback", agent) if result: return result # TIER 2: Simple conversations — VLM primary (up to 15K tok), fastest for moderate chat @@ -224,24 +235,24 @@ def route(rd, tier): return {"model":"qwen3.5-9b-vlm","reason":"simple_conv"} # VLM busy — fall back to Dense, then MoE fallback = [m for m in ["qwen3.6-27B-code","qwen3.6-35B-A3B"] if m in avail] - result = select_best_gpu(fallback, "simple_conv_fallback") + result = select_best_gpu(fallback, "simple_conv_fallback", agent) if result: return result # TIER 3: Medium complexity — Dense primary, VLM fallback (quality + speed balance) if t <= 25000: candidates = [m for m in ["qwen3.6-27B-code","qwen3.5-9b-vlm","qwen3.6-35B-A3B"] if m in avail] - result = select_best_gpu(candidates, "medium") + result = select_best_gpu(candidates, "medium", agent) if result: return result # TIER 4: Heavy reasoning — MoE primary (workhorse), Dense fallback if t > 25000: candidates = [m for m in ["qwen3.6-35B-A3B","qwen3.6-27B-code","qwen3.5-9b-vlm"] if m in avail] - result = select_best_gpu(candidates, "heavy_reasoning") + result = select_best_gpu(candidates, "heavy_reasoning", agent) if result: return result # TIER 5: Default — Dense primary, MoE fallback candidates = [m for m in ["qwen3.6-27B-code","qwen3.5-9b-vlm","qwen3.6-35B-A3B"] if m in avail] - result = select_best_gpu(candidates, "default") + result = select_best_gpu(candidates, "default", agent) if result: return result return {"model":avail[0],"reason":"last_resort"} @@ -313,7 +324,7 @@ def chat(): r.set("session:" + session_id, session_tokens, ex=86400) # TTL 24h except Exception: pass - d = route(rd, tier) + d = route(rd, tier, agent) queue_start = time.time() # Queue loop: wait for a GPU slot instead of immediate 503 @@ -325,7 +336,7 @@ def chat(): log.warning("QUEUE_TIMEOUT: %s waited %.1fs, all GPUs saturated", agent, elapsed) return resp, 503 time.sleep(0.5) # poll every 500ms - d = route(rd, tier) + d = route(rd, tier, agent) queue_ms = (time.time() - queue_start) * 1000 if queue_ms > 500: @@ -336,6 +347,10 @@ def chat(): gpu_incr(model) log.info("ROUTE: %s -> %s (%s) stream=%s active=%d/%d", agent, model, reason, is_stream, gpu_active_count(model), GPU_MAX_CONCURRENT.get(model,1)) + # Track which GPU this agent is using (TTL 120s covers typical request) + if r and agent: + try: r.setex("agent_gpu:" + agent + ":" + model, 120, "1") + except: pass if r: try: r.incr("routes:"+model); r.incr("routes:tier:"+tier); r.incr("routes:agent:"+agent)