# Files
# motia-iii/steps/ai/chat_completions_api_step.py
#
# 530 lines
# 20 KiB
# Python

"""AI Chat Completions API
Universal OpenAI-compatible Chat Completions API with xAI/LangChain Backend.
Features:
- File Search (RAG) via xAI Collections
- Web Search via xAI web_search tool
- Aktenzeichen-based automatic collection lookup
- Streaming & Non-Streaming support
- Multiple tools simultaneously (file_search + web_search)
"""
import json
import time
from typing import Any, Dict, List, Optional
from motia import FlowContext, http, ApiRequest, ApiResponse
# Step configuration: a single HTTP POST trigger on /ai/chat/completions,
# registered under the "ai-general" flow.
config = {
    "name": "AI Chat Completions API",
    "description": "Universal OpenAI-compatible Chat Completions API with xAI backend, RAG, and web search",
    "flows": ["ai-general"],
    "triggers": [http("POST", "/ai/chat/completions")],
}
async def handler(request: ApiRequest, ctx: FlowContext[Any]) -> ApiResponse:
    """
    OpenAI-compatible Chat Completions endpoint.

    Request Body (OpenAI format):
    {
        "model": "grok-4.20-beta-0309-reasoning",
        "messages": [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "1234/56 Was ist der Stand?"}
        ],
        "temperature": 0.7,
        "max_tokens": 2000,
        "stream": false,
        "extra_body": {
            "collection_id": "col_abc123",    // Optional: override auto-detection
            "enable_web_search": true,        // Optional: enable web search (default: false)
            "web_search_config": {            // Optional: web search configuration
                "allowed_domains": ["example.com"],
                "excluded_domains": ["spam.com"],
                "enable_image_understanding": true
            }
        }
    }

    Collection resolution (priority):
        1. extra_body.collection_id (explicit override)
        2. First user message starts with an Aktenzeichen (e.g., "1234/56 ...")
        3. Web-only mode if no collection_id (enable_web_search must be set)

    Response (OpenAI format):
        Non-Streaming:
        {
            "id": "chatcmpl-...",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "grok-4.20-beta-0309-reasoning",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": "..."},
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": X, "completion_tokens": Y, "total_tokens": Z}
        }

        Streaming (SSE):
        data: {"id":"chatcmpl-...","choices":[{"delta":{"content":"Hello"},...}]}
        data: {"id":"chatcmpl-...","choices":[{"delta":{"content":" world"},...}]}
        data: {"choices":[{"delta":{},"finish_reason":"stop"}]}
        data: [DONE]

    Args:
        request: Incoming API request; ``request.body`` is expected to be the
            parsed JSON object described above.
        ctx: Motia flow context (used for logging and the completion id).

    Returns:
        400 for malformed input (body, messages or extra_body of the wrong
        shape, or neither a collection nor web search available), 500 for
        service-configuration or unexpected errors, otherwise whatever the
        streaming / non-streaming helper returns.
    """
    # Service modules are imported at call time, as in the original step.
    from services.langchain_xai_service import LangChainXAIService
    from services.aktenzeichen_utils import (
        extract_aktenzeichen,
        normalize_aktenzeichen,
        remove_aktenzeichen,
    )

    ctx.logger.info("=" * 80)
    ctx.logger.info("🤖 AI CHAT COMPLETIONS API")
    ctx.logger.info("=" * 80)

    try:
        # Parse request body
        body = request.body or {}
        if not isinstance(body, dict):
            ctx.logger.error(f"❌ Invalid request body type: {type(body)}")
            return ApiResponse(
                status=400,
                body={'error': 'Request body must be JSON object'}
            )

        # Extract parameters
        model_name = body.get('model', 'grok-4.20-beta-0309-reasoning')
        messages = body.get('messages', [])
        temperature = body.get('temperature', 0.7)
        max_tokens = body.get('max_tokens')
        stream = body.get('stream', False)

        # Validate messages BEFORE anything iterates over them: a client
        # error must produce a 400, not an exception turned into a 500.
        if not messages or not isinstance(messages, list):
            ctx.logger.error("❌ Missing or invalid messages array")
            return ApiResponse(
                status=400,
                body={'error': 'messages must be non-empty array'}
            )
        if not all(isinstance(msg, dict) for msg in messages):
            ctx.logger.error("❌ Invalid entry in messages array")
            return ApiResponse(
                status=400,
                body={'error': 'each message must be a JSON object'}
            )

        # An explicit `"extra_body": null` is treated like an absent one.
        extra_body = body.get('extra_body') or {}
        if not isinstance(extra_body, dict):
            ctx.logger.error(f"❌ Invalid extra_body type: {type(extra_body)}")
            return ApiResponse(
                status=400,
                body={'error': 'extra_body must be JSON object'}
            )

        # Web Search parameters (default: disabled)
        enable_web_search = extra_body.get('enable_web_search', False)
        web_search_config = extra_body.get('web_search_config', {})

        ctx.logger.info(f"📋 Model: {model_name}")
        ctx.logger.info(f"📋 Messages: {len(messages)}")
        ctx.logger.info(f"📋 Stream: {stream}")
        ctx.logger.info(f"📋 Web Search: {'enabled' if enable_web_search else 'disabled'}")
        if enable_web_search and web_search_config:
            ctx.logger.debug(f"Web Search Config: {json.dumps(web_search_config, indent=2)}")

        # Log a short preview of every conversation message
        ctx.logger.info("-" * 80)
        ctx.logger.info("📨 REQUEST MESSAGES:")
        for i, msg in enumerate(messages, 1):
            role = msg.get('role', 'unknown')
            content = msg.get('content', '')
            # Content may be structured (list of parts) or null; serialise it
            # for the preview instead of slicing it as if it were a string.
            if isinstance(content, str):
                text = content
            else:
                text = json.dumps(content, ensure_ascii=False, default=str)
            preview = text[:150] + "..." if len(text) > 150 else text
            ctx.logger.info(f" [{i}] {role}: {preview}")
        ctx.logger.info("-" * 80)

        # Determine collection_id (Priority: extra_body > Aktenzeichen > optional for web-only)
        collection_id: Optional[str] = None
        aktenzeichen: Optional[str] = None

        if 'collection_id' in extra_body:
            # Priority 1: Explicit collection_id in extra_body
            collection_id = extra_body['collection_id']
            ctx.logger.info(f"🔍 Collection ID from extra_body: {collection_id}")
        else:
            # Priority 2: Extract Aktenzeichen from the first user message only
            first_user_msg = next(
                (msg for msg in messages if msg.get('role') == 'user'),
                None
            )
            if first_user_msg is not None:
                content = first_user_msg.get('content', '')
                # Only plain-text content can carry a leading Aktenzeichen.
                aktenzeichen_raw = (
                    extract_aktenzeichen(content) if isinstance(content, str) else None
                )
                if aktenzeichen_raw:
                    aktenzeichen = normalize_aktenzeichen(aktenzeichen_raw)
                    ctx.logger.info(f"🔍 Aktenzeichen detected: {aktenzeichen}")

                    # Lookup collection_id via EspoCRM
                    collection_id = await lookup_collection_by_aktenzeichen(
                        aktenzeichen, ctx
                    )
                    if collection_id:
                        ctx.logger.info(f"✅ Collection found: {collection_id}")
                        # Remove Aktenzeichen from message (clean prompt)
                        first_user_msg['content'] = remove_aktenzeichen(content)
                        ctx.logger.debug(f"Cleaned message: {first_user_msg['content']}")
                    else:
                        ctx.logger.warn(f"⚠️ No collection found for {aktenzeichen}")

        # Priority 3: Error if no collection_id AND web_search disabled
        if not collection_id and not enable_web_search:
            ctx.logger.error("❌ No collection_id found and web_search disabled")
            ctx.logger.error(" Provide collection_id, enable web_search, or both")
            return ApiResponse(
                status=400,
                body={
                    'error': 'collection_id or web_search required',
                    'message': 'Provide collection_id in extra_body, enable web_search, or start message with Aktenzeichen (e.g., "1234/56 question")'
                }
            )

        # Initialize LangChain xAI Service
        try:
            langchain_service = LangChainXAIService(ctx)
        except ValueError as e:
            ctx.logger.error(f"❌ Service initialization failed: {e}")
            return ApiResponse(
                status=500,
                body={'error': 'Service configuration error', 'details': str(e)}
            )

        # Create ChatXAI model
        model = langchain_service.get_chat_model(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens
        )

        # Bind tools (file_search and/or web_search)
        model_with_tools = langchain_service.bind_tools(
            model=model,
            collection_id=collection_id,
            enable_web_search=enable_web_search,
            web_search_config=web_search_config,
            max_num_results=10
        )

        # Generate completion_id (trace id prefix when available, else a timestamp)
        completion_id = f"chatcmpl-{ctx.traceId[:12]}" if hasattr(ctx, 'traceId') else f"chatcmpl-{int(time.time())}"
        created_ts = int(time.time())

        # Branch: Streaming vs Non-Streaming
        if stream:
            ctx.logger.info("🌊 Starting streaming response...")
            return await handle_streaming_response(
                model_with_tools=model_with_tools,
                messages=messages,
                completion_id=completion_id,
                created_ts=created_ts,
                model_name=model_name,
                langchain_service=langchain_service,
                ctx=ctx
            )
        else:
            ctx.logger.info("📦 Starting non-streaming response...")
            return await handle_non_streaming_response(
                model_with_tools=model_with_tools,
                messages=messages,
                completion_id=completion_id,
                created_ts=created_ts,
                model_name=model_name,
                langchain_service=langchain_service,
                ctx=ctx
            )
    except Exception as e:
        # Top-level boundary: log everything and answer with a generic 500.
        ctx.logger.error("=" * 80)
        ctx.logger.error("❌ ERROR: AI CHAT COMPLETIONS API")
        ctx.logger.error("=" * 80)
        ctx.logger.error(f"Error: {e}", exc_info=True)
        # default=str keeps this log line from raising on a body that is not
        # JSON-serialisable, which would otherwise mask the original error.
        ctx.logger.error(f"Request body: {json.dumps(request.body, indent=2, ensure_ascii=False, default=str)}")
        ctx.logger.error("=" * 80)
        return ApiResponse(
            status=500,
            body={
                'error': 'Internal server error',
                'message': str(e)
            }
        )
async def handle_non_streaming_response(
    model_with_tools,
    messages: List[Dict[str, Any]],
    completion_id: str,
    created_ts: int,
    model_name: str,
    langchain_service,
    ctx: FlowContext
) -> ApiResponse:
    """
    Handle non-streaming chat completion.

    Invokes the model once via ``langchain_service.invoke_chat`` and wraps the
    answer in an OpenAI ``chat.completion`` body.

    Args:
        model_with_tools: Chat model with tools already bound.
        messages: Conversation messages in OpenAI format.
        completion_id: Id to report in the response body.
        created_ts: Unix timestamp to report in the response body.
        model_name: Model name to report in the response body.
        langchain_service: Service exposing ``invoke_chat``.
        ctx: Motia context (used for logging).

    Returns:
        ApiResponse with OpenAI-format JSON body

    Raises:
        Exception: Any failure is logged and re-raised for the caller.
    """
    try:
        # Invoke model
        result = await langchain_service.invoke_chat(model_with_tools, messages)

        # Extract content - handle both string and structured responses
        text_blocks: List[Dict[str, Any]] = []
        if hasattr(result, 'content'):
            raw_content = result.content
            if isinstance(raw_content, list):
                # Structured content (tool calls + text): keep the text blocks
                text_blocks = [
                    item for item in raw_content
                    if isinstance(item, dict) and item.get('type') == 'text'
                ]
                if text_blocks:
                    # Join every text block, matching the streaming path,
                    # which also concatenates all text parts.
                    content = ''.join(block.get('text') or '' for block in text_blocks)
                else:
                    content = str(raw_content)
            else:
                content = raw_content
        else:
            content = str(result)

        # Build OpenAI-compatible response
        response_body = {
            'id': completion_id,
            'object': 'chat.completion',
            'created': created_ts,
            'model': model_name,
            'choices': [{
                'index': 0,
                'message': {
                    'role': 'assistant',
                    'content': content
                },
                'finish_reason': 'stop'
            }],
            'usage': {
                'prompt_tokens': 0,  # Overwritten below when usage metadata is present
                'completion_tokens': 0,
                'total_tokens': 0
            }
        }

        # Report token usage (if available)
        if hasattr(result, 'usage_metadata'):
            usage = result.usage_metadata
            if isinstance(usage, dict):
                # LangChain exposes usage_metadata as a dict, so the counts
                # must be read by key; getattr() on a dict always hits the default.
                prompt_tokens = usage.get('input_tokens') or 0
                completion_tokens = usage.get('output_tokens') or 0
            else:
                # Attribute-style object (or None): fall back to getattr.
                prompt_tokens = getattr(usage, 'input_tokens', 0) or 0
                completion_tokens = getattr(usage, 'output_tokens', 0) or 0
            response_body['usage'] = {
                'prompt_tokens': prompt_tokens,
                'completion_tokens': completion_tokens,
                'total_tokens': prompt_tokens + completion_tokens
            }
            ctx.logger.info(f"📊 Token Usage: prompt={prompt_tokens}, completion={completion_tokens}")

        # Log citations if available (annotations on the text blocks)
        for block in text_blocks:
            annotations = block.get('annotations') or []
            if not annotations:
                continue
            ctx.logger.info(f"🔗 Citations: {len(annotations)}")
            for i, citation in enumerate(annotations[:10], 1):  # Log first 10
                if not isinstance(citation, dict):
                    continue
                # A null url must not break logging
                url = citation.get('url') or 'N/A'
                title = citation.get('title', '')
                if isinstance(url, str) and url.startswith('collections://'):
                    # Internal collection reference
                    ctx.logger.debug(f" [{i}] Collection Document: {title}")
                else:
                    # External URL
                    ctx.logger.debug(f" [{i}] {url}")

        # Log complete response content
        ctx.logger.info(f"✅ Chat completion: {len(content)} chars")
        ctx.logger.info("=" * 80)
        ctx.logger.info("📝 COMPLETE RESPONSE:")
        ctx.logger.info("-" * 80)
        ctx.logger.info(content)
        ctx.logger.info("-" * 80)
        ctx.logger.info("=" * 80)

        return ApiResponse(
            status=200,
            body=response_body
        )
    except Exception as e:
        ctx.logger.error(f"❌ Non-streaming completion failed: {e}", exc_info=True)
        raise
async def handle_streaming_response(
    model_with_tools,
    messages: List[Dict[str, Any]],
    completion_id: str,
    created_ts: int,
    model_name: str,
    langchain_service,
    ctx: FlowContext
):
    """
    Handle streaming chat completion via SSE.

    Args:
        model_with_tools: Chat model with tools already bound.
        messages: Conversation messages in OpenAI format.
        completion_id: Id repeated in every emitted chunk.
        created_ts: Unix timestamp repeated in every emitted chunk.
        model_name: Model name repeated in every emitted chunk.
        langchain_service: Service exposing ``astream_chat``.
        ctx: Motia context; ``ctx.response`` is used to write the SSE stream.

    Returns:
        The (not yet awaited) ``stream_generator()`` coroutine, which writes
        the SSE events when it is run.
    """

    async def stream_generator():
        # Tracks whether the stream was already closed, so the error path
        # never tries to write to a closed response.
        stream_closed = False
        try:
            # Set SSE headers
            await ctx.response.status(200)
            await ctx.response.headers({
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            })
            ctx.logger.info("🌊 Streaming started")

            # Stream chunks
            chunk_count = 0
            content_parts = []  # joined once at the end instead of repeated +=
            async for chunk in langchain_service.astream_chat(model_with_tools, messages):
                # Extract delta content - handle structured chunks
                if hasattr(chunk, "content"):
                    chunk_content = chunk.content
                    if isinstance(chunk_content, list):
                        # Structured chunk (tool calls + text): keep only text deltas
                        delta = ''.join(
                            item.get('text', '')
                            for item in chunk_content
                            if isinstance(item, dict) and item.get('type') == 'text'
                        )
                    else:
                        delta = chunk_content
                else:
                    delta = ""

                # Chunks without text are not forwarded
                if not delta:
                    continue

                content_parts.append(delta)
                chunk_count += 1

                # Build SSE data
                data = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_ts,
                    "model": model_name,
                    "choices": [{
                        "index": 0,
                        "delta": {"content": delta},
                        "finish_reason": None
                    }]
                }
                # Send SSE event
                await ctx.response.stream(f"data: {json.dumps(data, ensure_ascii=False)}\n\n")

            # Send finish event
            finish_data = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created_ts,
                "model": model_name,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }
            await ctx.response.stream(f"data: {json.dumps(finish_data)}\n\n")

            # Send [DONE]
            await ctx.response.stream("data: [DONE]\n\n")

            # Close stream
            await ctx.response.close()
            stream_closed = True

            # Log complete streamed response
            total_content = "".join(content_parts)
            ctx.logger.info(f"✅ Streaming completed: {chunk_count} chunks, {len(total_content)} chars")
            ctx.logger.info("=" * 80)
            ctx.logger.info("📝 COMPLETE STREAMED RESPONSE:")
            ctx.logger.info("-" * 80)
            ctx.logger.info(total_content)
            ctx.logger.info("-" * 80)
            ctx.logger.info("=" * 80)
        except Exception as e:
            ctx.logger.error(f"❌ Streaming failed: {e}", exc_info=True)
            if stream_closed:
                # The failure happened after the stream was closed; there is
                # nothing left to send to the client.
                return

            # Send error event. A failure here (e.g. the client is gone) must
            # neither escape nor prevent the close attempt below.
            error_data = {
                "error": {
                    "message": str(e),
                    "type": "server_error"
                }
            }
            try:
                await ctx.response.stream(f"data: {json.dumps(error_data)}\n\n")
            except Exception as send_error:
                ctx.logger.error(f"❌ Could not send error event: {send_error}")
            try:
                await ctx.response.close()
            except Exception as close_error:
                ctx.logger.error(f"❌ Could not close stream: {close_error}")

    # NOTE(review): stream_generator has no `yield`, so this returns an
    # un-awaited coroutine, not a generator. Nothing is streamed unless the
    # caller/framework awaits it - confirm against the Motia streaming contract.
    return stream_generator()
async def lookup_collection_by_aktenzeichen(
    aktenzeichen: str,
    ctx: FlowContext
) -> Optional[str]:
    """
    Lookup xAI Collection ID for Aktenzeichen via EspoCRM.

    Search strategy:
        1. Search for Raeumungsklage with matching advowareAkteBezeichner
        2. Return xaiCollectionId if found

    The lookup is best-effort: any failure is logged and reported as
    "not found" rather than raised.

    Args:
        aktenzeichen: Normalized Aktenzeichen (e.g., "1234/56")
        ctx: Motia context

    Returns:
        Collection ID or None if not found
    """
    try:
        # EspoCRMAPI is not imported at module level (the handler only imports
        # it into its own local scope), so it must be imported here; otherwise
        # the NameError is swallowed below and every lookup returns None.
        from services.espocrm import EspoCRMAPI

        # Initialize EspoCRM API
        espocrm = EspoCRMAPI(ctx)

        # Search Räumungsklage by advowareAkteBezeichner
        ctx.logger.info(f"🔍 Searching Räumungsklage for Aktenzeichen: {aktenzeichen}")
        # assumes search_entities returns a list of entity dicts - TODO confirm
        search_result = await espocrm.search_entities(
            entity_type='Raeumungsklage',
            where=[{
                'type': 'equals',
                'attribute': 'advowareAkteBezeichner',
                'value': aktenzeichen
            }],
            select=['id', 'xaiCollectionId', 'advowareAkteBezeichner'],
            maxSize=1
        )

        if not search_result:
            ctx.logger.warn(f"⚠️ No Räumungsklage found for {aktenzeichen}")
            return None

        entity = search_result[0]
        collection_id = entity.get('xaiCollectionId')
        if not collection_id:
            ctx.logger.warn(f"⚠️ Räumungsklage found but no xaiCollectionId: {entity.get('id')}")
            return None

        ctx.logger.info(f"✅ Found Räumungsklage: {entity.get('id')}")
        return collection_id
    except Exception as e:
        # Best-effort: a failed lookup must not abort the chat request.
        ctx.logger.error(f"❌ Collection lookup failed: {e}", exc_info=True)
        return None