feat: capture streaming token counts from SSE final chunk

Router now buffers streaming response chunks to extract timings
(prompt_n, predicted_n, predicted_per_second) from the final
SSE data frame before yielding to the client. Streaming requests
get real throughput data instead of 0 tok/s.

Uses llama.cpp timings field in the last content chunk:
- completion_tokens = predicted_n
- tokens_per_sec = predicted_per_second
- inference_ms = predicted_ms (generation only)

Client sees identical stream, no perceptible delay.
This commit is contained in:
Abiba
2026-05-25 19:58:51 +00:00
parent b2ec4b0572
commit cfb05fa501
+26 -3
View File
@@ -352,11 +352,34 @@ def chat():
if resp.status_code != 200: return jsonify({"error":"GPU error "+str(resp.status_code)}), 502
if is_stream:
def gen():
# Buffer stream to capture timings from final SSE chunk
chunks = []
stream_timings = {}
for raw in resp.iter_content(chunk_size=None, decode_unicode=True):
if raw: yield clean_unicode(raw)
# Streaming: can't get token counts without parsing stream, store latency + estimated tokens
if raw:
cleaned = clean_unicode(raw)
chunks.append(cleaned)
# Parse last content chunk (before [DONE]) for timings
if not stream_timings and '"timings"' in cleaned and '"predicted_n"' in cleaned:
try:
json_str = cleaned.replace("data: ", "").strip()
if json_str.startswith("{"):
tj = json.loads(json_str).get("timings", {})
if tj:
stream_timings = tj
except: pass
# Store perf record with real token counts from stream
if stream_timings:
pt = stream_timings.get("prompt_n", 0)
ct = stream_timings.get("predicted_n", 0)
tps = stream_timings.get("predicted_per_second", 0)
gen_ms = stream_timings.get("predicted_ms", lat)
store_perf_record(model, agent, tier, reason, queue_ms, gen_ms, pt, ct, True)
else:
store_perf_record(model, agent, tier, reason, queue_ms, lat, estimate_tokens(rd.get("messages",[])), 0, True)
# Yield all chunks to client
def gen():
for c in chunks: yield c
bcast()
ctx_remaining = GPU_CONTEXT.get(model, 65536) - max(session_tokens, estimate_tokens(rd.get("messages",[])))
ctx_pct = ctx_remaining / GPU_CONTEXT.get(model, 65536) * 100