Files
motia-iii/steps/ai/chat_completions_api_step.py
bsiggel 69f0c6a44d feat: Implement AI Chat Completions API with streaming support and models list endpoint
- Enhanced the AI Chat Completions API to support true streaming using async generators and proper SSE headers.
- Updated endpoint paths to align with OpenAI's API versioning.
- Improved logging for request details and error handling.
- Added a new AI Models List API to return available models compatible with chat completions.
- Refactored code for better readability and maintainability, including the extraction of common functionalities.
- Introduced a VMH-specific Chat Completions API with similar features and structure.
2026-03-18 21:30:59 +00:00

281 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI Chat Completions API
Universal OpenAI-compatible Chat Completions API with xAI/LangChain Backend.
Features:
- File Search (RAG) via xAI Collections
- Web Search via xAI web_search tool
- Aktenzeichen-based automatic collection lookup
- **Echtes Streaming** (async generator + proper SSE headers)
- Multiple tools simultaneously
"""
import json
import time
from typing import Any, Dict, List, Optional
from motia import FlowContext, http, ApiRequest, ApiResponse
# Motia step configuration: registers the handler on two POST routes so both
# the namespaced path and the bare OpenAI-style versioned path are accepted.
config = {
    "name": "AI Chat Completions API",
    "description": "Universal OpenAI-compatible Chat Completions API with xAI backend, RAG, and web search",
    "flows": ["ai-general"],
    "triggers": [
        http("POST", "/ai/v1/chat/completions"),
        http("POST", "/v1/chat/completions")
    ],
}
async def handler(request: ApiRequest, ctx: FlowContext[Any]) -> ApiResponse:
"""
OpenAI-compatible Chat Completions endpoint mit **echtem** Streaming.
"""
ctx.logger.info("=" * 80)
ctx.logger.info("🤖 AI CHAT COMPLETIONS API OPTIMIZED")
ctx.logger.info("=" * 80)
# Log request (sicher)
ctx.logger.info("📥 REQUEST DETAILS:")
if request.headers:
ctx.logger.info(" Headers:")
for header_name, header_value in request.headers.items():
if header_name.lower() == 'authorization':
ctx.logger.info(f" {header_name}: Bearer ***MASKED***")
else:
ctx.logger.info(f" {header_name}: {header_value}")
try:
# Parse body
body = request.body or {}
if not isinstance(body, dict):
return ApiResponse(status=400, body={'error': 'Request body must be JSON object'})
# Parameter extrahieren
model_name = body.get('model', 'grok-4.20-beta-0309-reasoning')
messages = body.get('messages', [])
temperature = body.get('temperature', 0.7)
max_tokens = body.get('max_tokens')
stream = body.get('stream', False)
extra_body = body.get('extra_body', {})
enable_web_search = body.get('enable_web_search', extra_body.get('enable_web_search', False))
web_search_config = body.get('web_search_config', extra_body.get('web_search_config', {}))
ctx.logger.info(f"📋 Model: {model_name} | Stream: {stream} | Web Search: {enable_web_search}")
# Messages loggen (kurz)
ctx.logger.info("📨 MESSAGES:")
for i, msg in enumerate(messages, 1):
preview = (msg.get('content', '')[:120] + "...") if len(msg.get('content', '')) > 120 else msg.get('content', '')
ctx.logger.info(f" [{i}] {msg.get('role')}: {preview}")
# === Collection + Aktenzeichen Logic (unverändert) ===
collection_id: Optional[str] = None
aktenzeichen: Optional[str] = None
if 'collection_id' in body:
collection_id = body['collection_id']
elif 'custom_collection_id' in body:
collection_id = body['custom_collection_id']
elif 'collection_id' in extra_body:
collection_id = extra_body['collection_id']
else:
for msg in messages:
if msg.get('role') == 'user':
content = msg.get('content', '')
from services.aktenzeichen_utils import extract_aktenzeichen, normalize_aktenzeichen, remove_aktenzeichen
aktenzeichen_raw = extract_aktenzeichen(content)
if aktenzeichen_raw:
aktenzeichen = normalize_aktenzeichen(aktenzeichen_raw)
collection_id = await lookup_collection_by_aktenzeichen(aktenzeichen, ctx)
if collection_id:
msg['content'] = remove_aktenzeichen(content)
break
if not collection_id and not enable_web_search:
return ApiResponse(
status=400,
body={'error': 'collection_id or web_search required'}
)
# === Service initialisieren ===
from services.langchain_xai_service import LangChainXAIService
langchain_service = LangChainXAIService(ctx)
model = langchain_service.get_chat_model(
model=model_name,
temperature=temperature,
max_tokens=max_tokens
)
model_with_tools = langchain_service.bind_tools(
model=model,
collection_id=collection_id,
enable_web_search=enable_web_search,
web_search_config=web_search_config,
max_num_results=10
)
completion_id = f"chatcmpl-{ctx.traceId[:12]}" if hasattr(ctx, 'traceId') else f"chatcmpl-{int(time.time())}"
created_ts = int(time.time())
# ====================== ECHTES STREAMING ======================
if stream:
ctx.logger.info("🌊 Starting REAL SSE streaming (async generator)...")
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # nginx / proxies
"Transfer-Encoding": "chunked",
}
async def sse_generator():
# Initial chunk (manche Clients brauchen das)
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": model_name, "choices": [{"index": 0, "delta": {}, "finish_reason": None}]}, ensure_ascii=False)}\n\n'
chunk_count = 0
async for chunk in langchain_service.astream_chat(model_with_tools, messages):
delta = ""
if hasattr(chunk, "content"):
content = chunk.content
if isinstance(content, str):
delta = content
elif isinstance(content, list):
text_parts = [item.get('text', '') for item in content if isinstance(item, dict) and item.get('type') == 'text']
delta = ''.join(text_parts)
if delta:
chunk_count += 1
data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": delta},
"finish_reason": None
}]
}
yield f'data: {json.dumps(data, ensure_ascii=False)}\n\n'
# Finish
finish = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
}
yield f'data: {json.dumps(finish, ensure_ascii=False)}\n\n'
yield "data: [DONE]\n\n"
ctx.logger.info(f"✅ Streaming abgeschlossen {chunk_count} Chunks gesendet")
return ApiResponse(
status=200,
headers=headers,
body=sse_generator() # ← async generator = echtes Streaming!
)
# ====================== NON-STREAMING (unverändert + optimiert) ======================
else:
return await handle_non_streaming_response(
model_with_tools=model_with_tools,
messages=messages,
completion_id=completion_id,
created_ts=created_ts,
model_name=model_name,
langchain_service=langchain_service,
ctx=ctx
)
except Exception as e:
ctx.logger.error(f"❌ ERROR: {e}", exc_info=True)
return ApiResponse(
status=500,
body={'error': 'Internal server error', 'message': str(e)}
)
async def handle_non_streaming_response(
    model_with_tools,
    messages: List[Dict[str, Any]],
    completion_id: str,
    created_ts: int,
    model_name: str,
    langchain_service,
    ctx: FlowContext
) -> ApiResponse:
    """Invoke the model once and build an OpenAI ``chat.completion`` response.

    Args:
        model_with_tools: chat model with tools already bound.
        messages: OpenAI-style message dicts.
        completion_id: pre-generated ``chatcmpl-...`` id.
        created_ts: unix timestamp for the ``created`` field.
        model_name: model identifier echoed back in the response.
        langchain_service: service exposing ``invoke_chat``.
        ctx: flow context used for logging.

    Raises:
        Exception: re-raised after logging so the caller maps it to a 500.
    """
    try:
        result = await langchain_service.invoke_chat(model_with_tools, messages)

        # Extract content (compatible with xAI structured output).
        if hasattr(result, 'content'):
            raw = result.content
            if isinstance(raw, list):
                text_parts = [item.get('text', '') for item in raw if isinstance(item, dict) and item.get('type') == 'text']
                content = ''.join(text_parts) or str(raw)
            elif isinstance(raw, str):
                content = raw
            else:
                # Original assigned raw directly; a non-str (e.g. None) later
                # crashed len(content). Coerce defensively.
                content = '' if raw is None else str(raw)
        else:
            content = str(result)

        # Token usage, if the backend reports it. LangChain's
        # AIMessage.usage_metadata is a TypedDict, so read it with .get();
        # the original used getattr() only, which always returned 0 for
        # dict-shaped metadata. Attribute access kept as a fallback.
        usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        u = getattr(result, 'usage_metadata', None)
        if u:
            def _tokens(field: str) -> int:
                if isinstance(u, dict):
                    return int(u.get(field, 0) or 0)
                return int(getattr(u, field, 0) or 0)
            prompt_tokens = _tokens('input_tokens')
            completion_tokens = _tokens('output_tokens')
            usage = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }

        response_body = {
            'id': completion_id,
            'object': 'chat.completion',
            'created': created_ts,
            'model': model_name,
            'choices': [{
                'index': 0,
                'message': {'role': 'assistant', 'content': content},
                'finish_reason': 'stop'
            }],
            'usage': usage
        }
        ctx.logger.info(f"✅ Non-streaming fertig {len(content)} Zeichen")
        return ApiResponse(status=200, body=response_body)
    except Exception as e:
        ctx.logger.error(f"❌ Non-streaming failed: {e}")
        raise
async def lookup_collection_by_aktenzeichen(aktenzeichen: str, ctx: FlowContext) -> Optional[str]:
    """Resolve an Aktenzeichen (case reference) to an xAI collection id.

    Searches EspoCRM for the ``Raeumungsklage`` entity whose
    ``advowareAkteBezeichner`` equals *aktenzeichen* and returns its
    ``xaiCollectionId``.

    Returns:
        The collection id, or None when no entity matches, the entity has no
        collection id, or the lookup fails. Best-effort: errors are logged,
        never raised.
    """
    try:
        from services.espocrm import EspoCRMAPI
        espocrm = EspoCRMAPI(ctx)
        ctx.logger.info(f"🔍 Suche Räumungsklage für Aktenzeichen: {aktenzeichen}")
        search_result = await espocrm.search_entities(
            entity_type='Raeumungsklage',
            where=[{'type': 'equals', 'attribute': 'advowareAkteBezeichner', 'value': aktenzeichen}],
            select=['id', 'xaiCollectionId'],
            maxSize=1
        )
        # Truthiness covers both None and an empty list; the original's
        # extra `len(...) > 0` check was redundant.
        if search_result:
            collection_id = search_result[0].get('xaiCollectionId')
            if collection_id:
                ctx.logger.info(f"✅ Collection gefunden: {collection_id}")
                return collection_id
        return None
    except Exception as e:
        ctx.logger.error(f"❌ Lookup failed: {e}")
        return None