diff --git a/router/router.py b/router/router.py index f47e0de..d3c183c 100644 --- a/router/router.py +++ b/router/router.py @@ -17,6 +17,13 @@ GPU_URLS = { "qwen3.6-27B-code": GPU_DENSE_URL, "gemma-4-E4B": GPU_LIGHT_URL, } +# Max concurrent requests per GPU (based on llama.cpp --parallel) +GPU_MAX_CONCURRENT = { + "qwen3.6-35B-A3B": 2, # 2 slots + "qwen3.6-27B-code": 2, # 2 slots + "gemma-4-E4B": 1, # 1 slot +} + TIER_MODELS = { "starter": ["gemma-4-E4B"], "professional": ["qwen3.6-35B-A3B", "qwen3.6-27B-code", "gemma-4-E4B"], @@ -41,6 +48,18 @@ except Exception: r = None app = Flask(__name__) sse_subscribers = []; sse_lock = threading.Lock() +def gpu_active_count(model): + """Get number of in-flight requests for a GPU.""" + if r: + return int(r.get("active:" + model) or 0) + return 0 + +def gpu_incr(model): + if r: r.incr("active:" + model) + +def gpu_decr(model): + if r: r.decr("active:" + model) + def check_gpu_health(model): url = GPU_SIDECARS.get(model) if not url: return {"status": "unknown"} @@ -57,6 +76,28 @@ def available_models(): return [m for m in GPU_URLS if check_gpu_health(m)["stat def estimate_tokens(msgs): return sum(len(str(m.get("content",""))) for m in msgs) // 4 +def is_gpu_busy(model): + """Check if GPU is at or near max concurrent capacity.""" + active = gpu_active_count(model) + max_c = GPU_MAX_CONCURRENT.get(model, 1) + return active >= max_c + +def select_best_gpu(candidates, reason): + """Pick the best GPU from candidates, preferring least-loaded.""" + best = None + best_load = 999 + for m in candidates: + load = gpu_active_count(m) + if load < best_load: + best_load = load + best = m + if best: + actual_reason = reason + if is_gpu_busy(best): + actual_reason = "load_balanced_" + reason + return {"model": best, "reason": actual_reason} + return None + def route(rd, tier): msgs = rd.get("messages",[]); t = estimate_tokens(msgs) sys = any(m.get("role")=="system" for m in msgs) @@ -65,34 +106,52 @@ def route(rd, tier): allowed = TIER_MODELS.get(tier, ["gemma-4-E4B"]) avail = [m for m in available_models() if m in allowed] if not avail: return {"model": allowed[0], "reason": "all_saturated"} + req = rd.get("model","auto") - if req != "auto": return {"model": req if req in avail else avail[0], "reason": "explicit"} + if req != "auto": + target = req if req in avail else avail[0] + # If explicit model is busy, check if another can take it + if is_gpu_busy(target) and req in allowed: + alts = [m for m in avail if m != target and m in allowed] + if alts: + alt = select_best_gpu(alts, "explicit") + if alt: return alt + return {"model": target, "reason": "explicit"} + if hints: - if hints.get("priority")=="speed" and "gemma-4-E4B" in avail: return {"model":"gemma-4-E4B","reason":"hint_speed"} - if hints.get("priority")=="quality" and "qwen3.6-27B-code" in avail: return {"model":"qwen3.6-27B-code","reason":"hint_quality"} + if hints.get("priority")=="speed" and "gemma-4-E4B" in avail: + return select_best_gpu(["gemma-4-E4B"], "hint_speed") or {"model":"gemma-4-E4B","reason":"hint_speed"} + if hints.get("priority")=="quality" and "qwen3.6-27B-code" in avail: + return select_best_gpu(["qwen3.6-27B-code"], "hint_quality") or {"model":"qwen3.6-27B-code","reason":"hint_quality"} + + # Heavy -> dense (but fall back to MoE if dense is busy) if t > 4000 or sys or turns > 6: - for m in ["qwen3.6-27B-code","qwen3.6-35B-A3B","gemma-4-E4B"]: - if m in avail: return {"model":m,"reason":"heavy_reasoning"} + candidates = ["qwen3.6-27B-code","qwen3.6-35B-A3B","gemma-4-E4B"] + candidates = [m for m in candidates if m in avail] + result = select_best_gpu(candidates, "heavy_reasoning") + if result: return result + + # Ultra-light -> gemma first_msg = msgs[0].get("content","") if msgs else "" words = len(first_msg.split()) if isinstance(first_msg, str) else 99 if words <= 3 and turns <= 1 and not sys and "gemma-4-E4B" in avail: - return {"model":"gemma-4-E4B","reason":"ultra_light"} - if "qwen3.6-35B-A3B" in avail: return {"model":"qwen3.6-35B-A3B","reason":"default_moe"} + if not is_gpu_busy("gemma-4-E4B"): + return {"model":"gemma-4-E4B","reason":"ultra_light"} + + # Default: MoE, fall back to dense if MoE is busy + if "qwen3.6-35B-A3B" in avail: + if is_gpu_busy("qwen3.6-35B-A3B") and "qwen3.6-27B-code" in avail: + return {"model": "qwen3.6-27B-code", "reason": "load_balanced_default"} + return {"model":"qwen3.6-35B-A3B","reason":"default_moe"} + return {"model":avail[0],"reason":"fallback"} def clean_unicode(text): - if not isinstance(text, str): - return text - # Replace common Unicode punctuation with ASCII equivalents - text = text.replace(chr(0x2014), "-") - text = text.replace(chr(0x2013), "-") - text = text.replace(chr(0x2018), "'") - text = text.replace(chr(0x2019), "'") - text = text.replace(chr(0x201C), '"') - text = text.replace(chr(0x201D), '"') - text = text.replace(chr(0x2026), "...") - text = text.replace(chr(0x00A0), " ") - # Strip ALL remaining non-ASCII (emoji, symbols) + if not isinstance(text, str): return text + text = text.replace(chr(0x2014), "-"); text = text.replace(chr(0x2013), "-") + text = text.replace(chr(0x2018), "'"); text = text.replace(chr(0x2019), "'") + text = text.replace(chr(0x201C), '"'); text = text.replace(chr(0x201D), '"') + text = text.replace(chr(0x2026), "..."); text = text.replace(chr(0x00A0), " ") return text.encode("ascii", "ignore").decode("ascii") def clean_response(d): @@ -102,10 +161,11 @@ def clean_response(d): return d def get_metrics(): - d = {"gpus":[],"route_counts":{},"agent_counts":{},"tier_counts":{},"recent":[],"timestamp":time.time()} + d = {"gpus":[],"route_counts":{},"agent_counts":{},"tier_counts":{},"recent":[],"timestamp":time.time(),"active_requests":{}} for m in GPU_URLS: h = check_gpu_health(m) - d["gpus"].append({"id":m,"gpu_name":h.get("gpu_name",m),"status":h.get("status"),"vram_used_mb":h.get("vram_used_mb"),"vram_total_mb":h.get("vram_total_mb"),"vram_pct":h.get("vram_pct"),"temp_c":h.get("temp_c"),"gpu_util_pct":h.get("gpu_util_pct"),"power_w":h.get("power_w"),"power_limit_w":h.get("power_limit_w")}) + d["gpus"].append({"id":m,"gpu_name":h.get("gpu_name",m),"status":h.get("status"),"vram_used_mb":h.get("vram_used_mb"),"vram_total_mb":h.get("vram_total_mb"),"vram_pct":h.get("vram_pct"),"temp_c":h.get("temp_c"),"gpu_util_pct":h.get("gpu_util_pct"),"power_w":h.get("power_w"),"power_limit_w":h.get("power_limit_w"),"active_requests":gpu_active_count(m)}) + d["active_requests"][m] = gpu_active_count(m) if r: try: for m in GPU_URLS: d["route_counts"][m] = int(r.get("routes:"+m) or 0) @@ -136,7 +196,10 @@ def chat(): tier, agent = ki["tier"], ki["agent"] d = route(rd, tier); model, reason, url = d["model"], d["reason"], GPU_URLS[d["model"]] is_stream = rd.get("stream", False) - log.info("ROUTE: %s -> %s (%s) stream=%s", agent, model, reason, is_stream) + + gpu_incr(model) # Track active request + + log.info("ROUTE: %s -> %s (%s) stream=%s active=%d/%d", agent, model, reason, is_stream, gpu_active_count(model), GPU_MAX_CONCURRENT.get(model,1)) if r: try: r.incr("routes:"+model); r.incr("routes:tier:"+tier); r.incr("routes:agent:"+agent) @@ -148,6 +211,8 @@ def chat(): resp = requests.post(url+"/chat/completions", json=rd, headers={"Content-Type":"application/json","Authorization":"Bearer not-needed"}, timeout=120, stream=is_stream) lat = int((time.time()-start)*1000) + gpu_decr(model) # Release slot + if resp.status_code != 200: return jsonify({"error":"GPU error "+str(resp.status_code)}), 502 if is_stream: def gen(): @@ -160,10 +225,12 @@ def chat(): msg = c.get("message",{}) if not msg.get("content") and msg.get("reasoning_content"): msg["content"] = msg["reasoning_content"] - data["routing"] = {"model":model,"reason":reason,"gpu":url,"tier":tier,"agent":agent,"latency_ms":lat} + data["routing"] = {"model":model,"reason":reason,"gpu":url,"tier":tier,"agent":agent,"latency_ms":lat,"active_gpu":gpu_active_count(model)} bcast() return jsonify(data) - except requests.Timeout: return jsonify({"error":"timeout"}), 504 + except requests.Timeout: + gpu_decr(model if 'model' in dir() else "unknown") + return jsonify({"error":"timeout"}), 504 except Exception as e: log.error("Error: %s\n%s", e, traceback.format_exc()) return jsonify({"error":str(e)}), 500 @@ -172,7 +239,14 @@ def chat(): def models(): return jsonify({"object":"list","data":[{"id":m,"object":"model","owned_by":"syslog","status":check_gpu_health(m).get("status"),"gpu":check_gpu_health(m).get("gpu_name")} for m in GPU_URLS]}) @app.route("/health") -def health(): return jsonify({"status":"healthy","redis":"connected" if r else "down","gpus":{m:check_gpu_health(m) for m in GPU_URLS},"available_models":available_models()}) +def health(): + gpus = {} + for m in GPU_URLS: + h = check_gpu_health(m) + h["active_requests"] = gpu_active_count(m) + h["max_concurrent"] = GPU_MAX_CONCURRENT.get(m, 1) + gpus[m] = h + return jsonify({"status":"healthy","redis":"connected" if r else "down","gpus":gpus,"available_models":available_models()}) @app.route("/metrics") def metrics(): return jsonify(get_metrics()) @@ -220,5 +294,5 @@ def stream(): headers={"Cache-Control":"no-cache","X-Accel-Buffering":"no","Access-Control-Allow-Origin":"*"}) if __name__ == "__main__": - log.info("Router on :9000") + log.info("Router on :9000 (load-aware)") app.run(host="0.0.0.0", port=9000, debug=False)