Load-aware routing: tracks active GPU requests in Redis, distributes overflow when MoE saturated. 6 concurrent requests now spread across all 3 GPUs instead of queuing on one.
This commit is contained in:
+99
-25
@@ -17,6 +17,13 @@ GPU_URLS = {
|
||||
"qwen3.6-27B-code": GPU_DENSE_URL,
|
||||
"gemma-4-E4B": GPU_LIGHT_URL,
|
||||
}
|
||||
# Max concurrent requests per GPU (based on llama.cpp --parallel)
|
||||
GPU_MAX_CONCURRENT = {
|
||||
"qwen3.6-35B-A3B": 2, # 2 slots
|
||||
"qwen3.6-27B-code": 2, # 2 slots
|
||||
"gemma-4-E4B": 1, # 1 slot
|
||||
}
|
||||
|
||||
TIER_MODELS = {
|
||||
"starter": ["gemma-4-E4B"],
|
||||
"professional": ["qwen3.6-35B-A3B", "qwen3.6-27B-code", "gemma-4-E4B"],
|
||||
@@ -41,6 +48,18 @@ except Exception: r = None
|
||||
app = Flask(__name__)
|
||||
sse_subscribers = []; sse_lock = threading.Lock()
|
||||
|
||||
def gpu_active_count(model):
|
||||
"""Get number of in-flight requests for a GPU."""
|
||||
if r:
|
||||
return int(r.get("active:" + model) or 0)
|
||||
return 0
|
||||
|
||||
def gpu_incr(model):
|
||||
if r: r.incr("active:" + model)
|
||||
|
||||
def gpu_decr(model):
|
||||
if r: r.decr("active:" + model)
|
||||
|
||||
def check_gpu_health(model):
|
||||
url = GPU_SIDECARS.get(model)
|
||||
if not url: return {"status": "unknown"}
|
||||
@@ -57,6 +76,28 @@ def available_models(): return [m for m in GPU_URLS if check_gpu_health(m)["stat
|
||||
|
||||
def estimate_tokens(msgs): return sum(len(str(m.get("content",""))) for m in msgs) // 4
|
||||
|
||||
def is_gpu_busy(model):
|
||||
"""Check if GPU is at or near max concurrent capacity."""
|
||||
active = gpu_active_count(model)
|
||||
max_c = GPU_MAX_CONCURRENT.get(model, 1)
|
||||
return active >= max_c
|
||||
|
||||
def select_best_gpu(candidates, reason):
|
||||
"""Pick the best GPU from candidates, preferring least-loaded."""
|
||||
best = None
|
||||
best_load = 999
|
||||
for m in candidates:
|
||||
load = gpu_active_count(m)
|
||||
if load < best_load:
|
||||
best_load = load
|
||||
best = m
|
||||
if best:
|
||||
actual_reason = reason
|
||||
if is_gpu_busy(best):
|
||||
actual_reason = "load_balanced_" + reason
|
||||
return {"model": best, "reason": actual_reason}
|
||||
return None
|
||||
|
||||
def route(rd, tier):
|
||||
msgs = rd.get("messages",[]); t = estimate_tokens(msgs)
|
||||
sys = any(m.get("role")=="system" for m in msgs)
|
||||
@@ -65,34 +106,52 @@ def route(rd, tier):
|
||||
allowed = TIER_MODELS.get(tier, ["gemma-4-E4B"])
|
||||
avail = [m for m in available_models() if m in allowed]
|
||||
if not avail: return {"model": allowed[0], "reason": "all_saturated"}
|
||||
|
||||
req = rd.get("model","auto")
|
||||
if req != "auto": return {"model": req if req in avail else avail[0], "reason": "explicit"}
|
||||
if req != "auto":
|
||||
target = req if req in avail else avail[0]
|
||||
# If explicit model is busy, check if another can take it
|
||||
if is_gpu_busy(target) and req in allowed:
|
||||
alts = [m for m in avail if m != target and m in allowed]
|
||||
if alts:
|
||||
alt = select_best_gpu(alts, "explicit")
|
||||
if alt: return alt
|
||||
return {"model": target, "reason": "explicit"}
|
||||
|
||||
if hints:
|
||||
if hints.get("priority")=="speed" and "gemma-4-E4B" in avail: return {"model":"gemma-4-E4B","reason":"hint_speed"}
|
||||
if hints.get("priority")=="quality" and "qwen3.6-27B-code" in avail: return {"model":"qwen3.6-27B-code","reason":"hint_quality"}
|
||||
if hints.get("priority")=="speed" and "gemma-4-E4B" in avail:
|
||||
return select_best_gpu(["gemma-4-E4B"], "hint_speed") or {"model":"gemma-4-E4B","reason":"hint_speed"}
|
||||
if hints.get("priority")=="quality" and "qwen3.6-27B-code" in avail:
|
||||
return select_best_gpu(["qwen3.6-27B-code"], "hint_quality") or {"model":"qwen3.6-27B-code","reason":"hint_quality"}
|
||||
|
||||
# Heavy -> dense (but fall back to MoE if dense is busy)
|
||||
if t > 4000 or sys or turns > 6:
|
||||
for m in ["qwen3.6-27B-code","qwen3.6-35B-A3B","gemma-4-E4B"]:
|
||||
if m in avail: return {"model":m,"reason":"heavy_reasoning"}
|
||||
candidates = ["qwen3.6-27B-code","qwen3.6-35B-A3B","gemma-4-E4B"]
|
||||
candidates = [m for m in candidates if m in avail]
|
||||
result = select_best_gpu(candidates, "heavy_reasoning")
|
||||
if result: return result
|
||||
|
||||
# Ultra-light -> gemma
|
||||
first_msg = msgs[0].get("content","") if msgs else ""
|
||||
words = len(first_msg.split()) if isinstance(first_msg, str) else 99
|
||||
if words <= 3 and turns <= 1 and not sys and "gemma-4-E4B" in avail:
|
||||
if not is_gpu_busy("gemma-4-E4B"):
|
||||
return {"model":"gemma-4-E4B","reason":"ultra_light"}
|
||||
if "qwen3.6-35B-A3B" in avail: return {"model":"qwen3.6-35B-A3B","reason":"default_moe"}
|
||||
|
||||
# Default: MoE, fall back to dense if MoE is busy
|
||||
if "qwen3.6-35B-A3B" in avail:
|
||||
if is_gpu_busy("qwen3.6-35B-A3B") and "qwen3.6-27B-code" in avail:
|
||||
return {"model": "qwen3.6-27B-code", "reason": "load_balanced_default"}
|
||||
return {"model":"qwen3.6-35B-A3B","reason":"default_moe"}
|
||||
|
||||
return {"model":avail[0],"reason":"fallback"}
|
||||
|
||||
def clean_unicode(text):
|
||||
if not isinstance(text, str):
|
||||
return text
|
||||
# Replace common Unicode punctuation with ASCII equivalents
|
||||
text = text.replace(chr(0x2014), "-")
|
||||
text = text.replace(chr(0x2013), "-")
|
||||
text = text.replace(chr(0x2018), "'")
|
||||
text = text.replace(chr(0x2019), "'")
|
||||
text = text.replace(chr(0x201C), '"')
|
||||
text = text.replace(chr(0x201D), '"')
|
||||
text = text.replace(chr(0x2026), "...")
|
||||
text = text.replace(chr(0x00A0), " ")
|
||||
# Strip ALL remaining non-ASCII (emoji, symbols)
|
||||
if not isinstance(text, str): return text
|
||||
text = text.replace(chr(0x2014), "-"); text = text.replace(chr(0x2013), "-")
|
||||
text = text.replace(chr(0x2018), "'"); text = text.replace(chr(0x2019), "'")
|
||||
text = text.replace(chr(0x201C), '"'); text = text.replace(chr(0x201D), '"')
|
||||
text = text.replace(chr(0x2026), "..."); text = text.replace(chr(0x00A0), " ")
|
||||
return text.encode("ascii", "ignore").decode("ascii")
|
||||
|
||||
def clean_response(d):
|
||||
@@ -102,10 +161,11 @@ def clean_response(d):
|
||||
return d
|
||||
|
||||
def get_metrics():
|
||||
d = {"gpus":[],"route_counts":{},"agent_counts":{},"tier_counts":{},"recent":[],"timestamp":time.time()}
|
||||
d = {"gpus":[],"route_counts":{},"agent_counts":{},"tier_counts":{},"recent":[],"timestamp":time.time(),"active_requests":{}}
|
||||
for m in GPU_URLS:
|
||||
h = check_gpu_health(m)
|
||||
d["gpus"].append({"id":m,"gpu_name":h.get("gpu_name",m),"status":h.get("status"),"vram_used_mb":h.get("vram_used_mb"),"vram_total_mb":h.get("vram_total_mb"),"vram_pct":h.get("vram_pct"),"temp_c":h.get("temp_c"),"gpu_util_pct":h.get("gpu_util_pct"),"power_w":h.get("power_w"),"power_limit_w":h.get("power_limit_w")})
|
||||
d["gpus"].append({"id":m,"gpu_name":h.get("gpu_name",m),"status":h.get("status"),"vram_used_mb":h.get("vram_used_mb"),"vram_total_mb":h.get("vram_total_mb"),"vram_pct":h.get("vram_pct"),"temp_c":h.get("temp_c"),"gpu_util_pct":h.get("gpu_util_pct"),"power_w":h.get("power_w"),"power_limit_w":h.get("power_limit_w"),"active_requests":gpu_active_count(m)})
|
||||
d["active_requests"][m] = gpu_active_count(m)
|
||||
if r:
|
||||
try:
|
||||
for m in GPU_URLS: d["route_counts"][m] = int(r.get("routes:"+m) or 0)
|
||||
@@ -136,7 +196,10 @@ def chat():
|
||||
tier, agent = ki["tier"], ki["agent"]
|
||||
d = route(rd, tier); model, reason, url = d["model"], d["reason"], GPU_URLS[d["model"]]
|
||||
is_stream = rd.get("stream", False)
|
||||
log.info("ROUTE: %s -> %s (%s) stream=%s", agent, model, reason, is_stream)
|
||||
|
||||
gpu_incr(model) # Track active request
|
||||
|
||||
log.info("ROUTE: %s -> %s (%s) stream=%s active=%d/%d", agent, model, reason, is_stream, gpu_active_count(model), GPU_MAX_CONCURRENT.get(model,1))
|
||||
if r:
|
||||
try:
|
||||
r.incr("routes:"+model); r.incr("routes:tier:"+tier); r.incr("routes:agent:"+agent)
|
||||
@@ -148,6 +211,8 @@ def chat():
|
||||
resp = requests.post(url+"/chat/completions", json=rd,
|
||||
headers={"Content-Type":"application/json","Authorization":"Bearer not-needed"}, timeout=120, stream=is_stream)
|
||||
lat = int((time.time()-start)*1000)
|
||||
gpu_decr(model) # Release slot
|
||||
|
||||
if resp.status_code != 200: return jsonify({"error":"GPU error "+str(resp.status_code)}), 502
|
||||
if is_stream:
|
||||
def gen():
|
||||
@@ -160,10 +225,12 @@ def chat():
|
||||
msg = c.get("message",{})
|
||||
if not msg.get("content") and msg.get("reasoning_content"):
|
||||
msg["content"] = msg["reasoning_content"]
|
||||
data["routing"] = {"model":model,"reason":reason,"gpu":url,"tier":tier,"agent":agent,"latency_ms":lat}
|
||||
data["routing"] = {"model":model,"reason":reason,"gpu":url,"tier":tier,"agent":agent,"latency_ms":lat,"active_gpu":gpu_active_count(model)}
|
||||
bcast()
|
||||
return jsonify(data)
|
||||
except requests.Timeout: return jsonify({"error":"timeout"}), 504
|
||||
except requests.Timeout:
|
||||
gpu_decr(model if 'model' in dir() else "unknown")
|
||||
return jsonify({"error":"timeout"}), 504
|
||||
except Exception as e:
|
||||
log.error("Error: %s\n%s", e, traceback.format_exc())
|
||||
return jsonify({"error":str(e)}), 500
|
||||
@@ -172,7 +239,14 @@ def chat():
|
||||
def models(): return jsonify({"object":"list","data":[{"id":m,"object":"model","owned_by":"syslog","status":check_gpu_health(m).get("status"),"gpu":check_gpu_health(m).get("gpu_name")} for m in GPU_URLS]})
|
||||
|
||||
@app.route("/health")
|
||||
def health(): return jsonify({"status":"healthy","redis":"connected" if r else "down","gpus":{m:check_gpu_health(m) for m in GPU_URLS},"available_models":available_models()})
|
||||
def health():
|
||||
gpus = {}
|
||||
for m in GPU_URLS:
|
||||
h = check_gpu_health(m)
|
||||
h["active_requests"] = gpu_active_count(m)
|
||||
h["max_concurrent"] = GPU_MAX_CONCURRENT.get(m, 1)
|
||||
gpus[m] = h
|
||||
return jsonify({"status":"healthy","redis":"connected" if r else "down","gpus":gpus,"available_models":available_models()})
|
||||
|
||||
@app.route("/metrics")
|
||||
def metrics(): return jsonify(get_metrics())
|
||||
@@ -220,5 +294,5 @@ def stream():
|
||||
headers={"Cache-Control":"no-cache","X-Accel-Buffering":"no","Access-Control-Allow-Origin":"*"})
|
||||
|
||||
if __name__ == "__main__":
|
||||
log.info("Router on :9000")
|
||||
log.info("Router on :9000 (load-aware)")
|
||||
app.run(host="0.0.0.0", port=9000, debug=False)
|
||||
|
||||
Reference in New Issue
Block a user