# Files
# motia-iii/steps/ai/chat_completions_api_step.py
#
# 386 lines
# 12 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI Chat Completions API
OpenAI-compatible Chat Completions endpoint with xAI/LangChain backend.
Features:
- File Search (RAG) via xAI Collections
- Web Search via xAI web_search tool
- Aktenzeichen-based automatic collection lookup
- Multiple tools simultaneously
- Clean, reusable architecture for future LLM endpoints
Note: Streaming is not supported (Motia limitation - returns clear error).
Reusability:
- extract_request_params(): Parse requests for any LLM endpoint
- resolve_collection_id(): Auto-detect Aktenzeichen, lookup collection
- initialize_model_with_tools(): Bind tools to any LangChain model
- invoke_and_format_response(): Standard OpenAI response formatting
"""
import time
from typing import Any, Dict, List, Optional
from motia import FlowContext, http, ApiRequest, ApiResponse
# Step configuration: declares the step's display name, description, the flow
# it belongs to, and the HTTP triggers that invoke it.
# NOTE(review): presumably the Motia runtime reads this module-level `config`
# and dispatches matching requests to `handler` below -- confirm against the
# Motia docs.
config = {
    "name": "AI Chat Completions API",
    "description": "OpenAI-compatible Chat Completions API with xAI backend",
    "flows": ["ai-general"],
    "triggers": [
        # The same endpoint is registered under two POST paths: a namespaced
        # one (/ai/v1/...) and the bare OpenAI-style one (/v1/...).
        http("POST", "/ai/v1/chat/completions"),
        http("POST", "/v1/chat/completions")
    ],
}
# ============================================================================
# MAIN HANDLER
# ============================================================================
async def handler(request: ApiRequest, ctx: FlowContext[Any]) -> ApiResponse:
    """
    OpenAI-compatible Chat Completions endpoint.

    Pipeline:
        1. Parse and validate the request body.
        2. Reject streaming requests (not supported).
        3. Resolve the collection (explicit ID or Aktenzeichen lookup).
        4. Require at least one of: a collection, or web search.
        5. Build the model with its tools bound.
        6. Invoke the model and return the formatted completion.

    Args:
        request: Incoming HTTP request; its ``body`` carries the chat payload.
        ctx: Motia context (used here for logging and passed to the helpers).

    Returns:
        ApiResponse with chat completion or error:
        - 200 with the completion on success
        - 400 on validation errors (``ValueError``) or when neither a
          collection nor web search is available
        - 501 when ``stream`` is requested
        - 500 on any other exception (details are logged, not returned)
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import uuid

    ctx.logger.info("=" * 80)
    ctx.logger.info("🤖 AI Chat Completions API")
    ctx.logger.info("=" * 80)
    try:
        # 1. Parse and validate request
        params = extract_request_params(request, ctx)

        # 2. Check streaming (not supported)
        if params['stream']:
            return ApiResponse(
                status=501,
                body={
                    'error': {
                        'message': 'Streaming is not supported. Please set stream=false.',
                        'type': 'not_implemented',
                        'param': 'stream'
                    }
                }
            )

        # 3. Resolve collection (explicit ID or Aktenzeichen lookup)
        collection_id = await resolve_collection_id(
            params['collection_id'],
            params['messages'],
            params['enable_web_search'],
            ctx
        )

        # 4. Validate: collection or web_search required
        if not collection_id and not params['enable_web_search']:
            return ApiResponse(
                status=400,
                body={
                    'error': {
                        'message': 'Either collection_id or enable_web_search must be provided',
                        'type': 'invalid_request_error'
                    }
                }
            )

        # 5. Initialize LLM with tools
        model_with_tools = await initialize_model_with_tools(
            model_name=params['model'],
            temperature=params['temperature'],
            max_tokens=params['max_tokens'],
            collection_id=collection_id,
            enable_web_search=params['enable_web_search'],
            web_search_config=params['web_search_config'],
            ctx=ctx
        )

        # 6. Invoke LLM
        # A second-resolution timestamp is not unique: two requests handled in
        # the same second would share an ID. Use a random UUID instead.
        completion_id = f"chatcmpl-{uuid.uuid4().hex}"
        response = await invoke_and_format_response(
            model=model_with_tools,
            messages=params['messages'],
            completion_id=completion_id,
            model_name=params['model'],
            ctx=ctx
        )
        ctx.logger.info(f"✅ Completion successful {len(response.body['choices'][0]['message']['content'])} chars")
        return response
    except ValueError as e:
        # Validation problems are the client's fault -> 400 with the message.
        ctx.logger.error(f"❌ Validation error: {e}")
        return ApiResponse(
            status=400,
            body={'error': {'message': str(e), 'type': 'invalid_request_error'}}
        )
    except Exception as e:
        # Top-level boundary: log the cause but never leak internals to the client.
        ctx.logger.error(f"❌ Error: {e}")
        return ApiResponse(
            status=500,
            body={'error': {'message': 'Internal server error', 'type': 'server_error'}}
        )
# ============================================================================
# REUSABLE HELPER FUNCTIONS
# ============================================================================
def extract_request_params(request: ApiRequest, ctx: FlowContext) -> Dict[str, Any]:
    """
    Extract and validate request parameters.

    Args:
        request: Incoming request; ``request.body`` is expected to be a JSON
            object (dict). A missing/empty body is treated as ``{}``.
        ctx: Motia context (used for logging).

    Returns:
        Dict with validated parameters:
        ``model``, ``messages``, ``temperature``, ``max_tokens``, ``stream``,
        ``extra_body``, ``enable_web_search``, ``web_search_config`` and
        ``collection_id``. ``extra_body`` is always a dict.

    Raises:
        ValueError: If the body is not an object, ``messages`` is missing,
            empty, not a list or contains non-object entries, or
            ``extra_body`` is present but not an object.
    """
    body = request.body or {}
    if not isinstance(body, dict):
        raise ValueError("Request body must be JSON object")

    messages = body.get('messages', [])
    if not messages or not isinstance(messages, list):
        raise ValueError("messages must be non-empty array")
    # Later stages call ``.get`` on every message; reject malformed entries
    # here so the client gets a 400 instead of an internal error.
    if not all(isinstance(message, dict) for message in messages):
        raise ValueError("each message must be a JSON object")

    # ``extra_body`` may be absent or explicitly null; both mean "no extras".
    # Anything else that is not an object is a client error.
    extra_body = body.get('extra_body')
    if extra_body is None:
        extra_body = {}
    elif not isinstance(extra_body, dict):
        raise ValueError("extra_body must be JSON object")

    # Extract parameters with defaults
    params = {
        'model': body.get('model', 'grok-4-1-fast-reasoning'),
        'messages': messages,
        'temperature': body.get('temperature', 0.7),
        'max_tokens': body.get('max_tokens'),
        'stream': body.get('stream', False),
        'extra_body': extra_body,
    }

    # Handle enable_web_search (top-level body wins over extra_body)
    params['enable_web_search'] = body.get(
        'enable_web_search',
        extra_body.get('enable_web_search', False)
    )

    # Handle web_search_config (top-level body wins over extra_body)
    params['web_search_config'] = body.get(
        'web_search_config',
        extra_body.get('web_search_config', {})
    )

    # Handle collection_id (multiple sources, first truthy value wins)
    params['collection_id'] = (
        body.get('collection_id') or
        body.get('custom_collection_id') or
        extra_body.get('collection_id')
    )

    # Log concisely
    ctx.logger.info(f"📋 Model: {params['model']} | Stream: {params['stream']}")
    ctx.logger.info(f"📋 Web Search: {params['enable_web_search']} | Collection: {params['collection_id'] or 'auto'}")
    ctx.logger.info(f"📨 Messages: {len(messages)}")
    return params
async def resolve_collection_id(
    explicit_collection_id: Optional[str],
    messages: List[Dict[str, Any]],
    enable_web_search: bool,
    ctx: FlowContext
) -> Optional[str]:
    """
    Resolve collection ID from explicit ID or Aktenzeichen auto-detection.

    Only the first message with role ``user`` is inspected. When an
    Aktenzeichen is found there and a collection exists for it, the
    Aktenzeichen is stripped from that message's ``content`` in place, so the
    caller's ``messages`` list reflects the cleaned text.

    Args:
        explicit_collection_id: Explicitly provided collection ID
        messages: Chat messages (for Aktenzeichen extraction)
        enable_web_search: Whether web search is enabled (currently unused
            here; kept for interface compatibility)
        ctx: Motia context

    Returns:
        Collection ID or None
    """
    # Explicit collection ID takes precedence
    if explicit_collection_id:
        ctx.logger.info(f"🔍 Using explicit collection: {explicit_collection_id}")
        return explicit_collection_id

    # Try Aktenzeichen auto-detection from the first user message only
    first_user_message = next(
        (msg for msg in messages if isinstance(msg, dict) and msg.get('role') == 'user'),
        None,
    )
    if first_user_message is None:
        return None

    # ``content`` may be missing/null or a list of content parts (OpenAI
    # multimodal format). The Aktenzeichen helpers are only fed plain text,
    # so anything else skips auto-detection instead of risking a crash.
    content = first_user_message.get('content')
    if not isinstance(content, str) or not content:
        return None

    # Imported lazily: only needed once there is text to inspect.
    from services.aktenzeichen_utils import (
        extract_aktenzeichen,
        normalize_aktenzeichen,
        remove_aktenzeichen
    )

    aktenzeichen_raw = extract_aktenzeichen(content)
    if not aktenzeichen_raw:
        return None

    aktenzeichen = normalize_aktenzeichen(aktenzeichen_raw)
    ctx.logger.info(f"🔍 Aktenzeichen detected: {aktenzeichen}")
    collection_id = await lookup_collection_by_aktenzeichen(aktenzeichen, ctx)
    if not collection_id:
        ctx.logger.warning(f"⚠️ No collection for Aktenzeichen: {aktenzeichen}")
        return None

    # Clean Aktenzeichen from message (in-place, visible to the caller)
    first_user_message['content'] = remove_aktenzeichen(content)
    ctx.logger.info(f"✅ Collection found: {collection_id}")
    return collection_id
async def initialize_model_with_tools(
    model_name: str,
    temperature: float,
    max_tokens: Optional[int],
    collection_id: Optional[str],
    enable_web_search: bool,
    web_search_config: Dict[str, Any],
    ctx: FlowContext
) -> Any:
    """
    Build a chat model via LangChainXAIService and bind its tools.

    Args:
        model_name: Model identifier forwarded to ``get_chat_model``.
        temperature: Sampling temperature forwarded to ``get_chat_model``.
        max_tokens: Optional completion token limit forwarded to
            ``get_chat_model``.
        collection_id: Collection ID forwarded to ``bind_tools`` (may be None).
        enable_web_search: Web search flag forwarded to ``bind_tools``.
        web_search_config: Web search options forwarded to ``bind_tools``.
        ctx: Motia context handed to the service constructor.

    Returns:
        Whatever ``bind_tools`` returns (the model with tools bound).
    """
    from services.langchain_xai_service import LangChainXAIService

    xai_service = LangChainXAIService(ctx)

    # Step 1: plain chat model with the requested sampling settings.
    base_model = xai_service.get_chat_model(
        model=model_name,
        temperature=temperature,
        max_tokens=max_tokens
    )

    # Step 2: attach the tools; result count is fixed at 10.
    return xai_service.bind_tools(
        model=base_model,
        collection_id=collection_id,
        enable_web_search=enable_web_search,
        web_search_config=web_search_config,
        max_num_results=10
    )
async def invoke_and_format_response(
    model: Any,
    messages: List[Dict[str, Any]],
    completion_id: str,
    model_name: str,
    ctx: FlowContext
) -> ApiResponse:
    """
    Invoke LLM and format response in OpenAI-compatible format.

    Args:
        model: Model (with tools bound) passed to the service's ``invoke_chat``.
        messages: Chat messages sent to the model.
        completion_id: ID placed in the response's ``id`` field.
        model_name: Model name echoed in the response's ``model`` field.
        ctx: Motia context (service construction and logging).

    Returns:
        ApiResponse (status 200) with a ``chat.completion`` body containing a
        single assistant choice and token usage.
    """
    from services.langchain_xai_service import LangChainXAIService

    service = LangChainXAIService(ctx)
    result = await service.invoke_chat(model, messages)

    content = _extract_text_content(result)
    usage = _extract_usage(result)

    # Log complete LLM response
    ctx.logger.info("=" * 80)
    ctx.logger.info("📤 LLM RESPONSE")
    ctx.logger.info("-" * 80)
    ctx.logger.info(f"Model: {model_name}")
    ctx.logger.info(f"Completion ID: {completion_id}")
    ctx.logger.info(f"Usage: {usage['prompt_tokens']} prompt + {usage['completion_tokens']} completion = {usage['total_tokens']} total tokens")
    ctx.logger.info("-" * 80)
    ctx.logger.info("Content:")
    ctx.logger.info(content)
    ctx.logger.info("=" * 80)

    # Format OpenAI-compatible response
    response_body = {
        'id': completion_id,
        'object': 'chat.completion',
        'created': int(time.time()),
        'model': model_name,
        'choices': [{
            'index': 0,
            'message': {'role': 'assistant', 'content': content},
            'finish_reason': 'stop'
        }],
        'usage': usage
    }
    return ApiResponse(status=200, body=response_body)


def _extract_text_content(result: Any) -> Any:
    """Return the text of an LLM result, flattening structured content parts."""
    if not hasattr(result, 'content'):
        return str(result)

    raw = result.content
    if isinstance(raw, list):
        # Structured response: keep only the text parts; if there are none,
        # fall back to the string form of the whole list.
        text_parts = [
            item.get('text', '')
            for item in raw
            if isinstance(item, dict) and item.get('type') == 'text'
        ]
        return ''.join(text_parts) or str(raw)
    return raw


def _extract_usage(result: Any) -> Dict[str, int]:
    """Build the OpenAI-style usage dict from a result's ``usage_metadata``."""
    metadata = getattr(result, 'usage_metadata', None)

    def token_count(key: str) -> int:
        # LangChain exposes usage_metadata as a TypedDict, i.e. a plain dict
        # at runtime, so attribute lookup on it always yields the default.
        # Read mapping keys first and fall back to attributes for
        # object-style metadata.
        if isinstance(metadata, dict):
            value = metadata.get(key)
        else:
            value = getattr(metadata, key, None)
        return value or 0

    prompt_tokens = token_count('input_tokens')
    completion_tokens = token_count('output_tokens')
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens
    }
async def lookup_collection_by_aktenzeichen(
    aktenzeichen: str,
    ctx: FlowContext
) -> Optional[str]:
    """
    Find the xAI collection ID stored in EspoCRM for an Aktenzeichen.

    Searches ``Raeumungsklage`` entities whose ``advowareAkteBezeichner``
    equals the given value (at most one record is requested) and returns
    that record's ``xaiCollectionId``.

    Args:
        aktenzeichen: Normalized Aktenzeichen (e.g., "1234/56")
        ctx: Motia context

    Returns:
        Collection ID, or None when nothing matches or the lookup fails
        (failures are logged, never raised).
    """
    try:
        from services.espocrm import EspoCRMAPI

        exact_match_filter = {
            'type': 'equals',
            'attribute': 'advowareAkteBezeichner',
            'value': aktenzeichen
        }
        matches = await EspoCRMAPI(ctx).search_entities(
            entity_type='Raeumungsklage',
            where=[exact_match_filter],
            select=['id', 'xaiCollectionId'],
            maxSize=1
        )

        # Guard clause: no usable result means no collection.
        if not matches or len(matches) <= 0:
            return None
        return matches[0].get('xaiCollectionId')
    except Exception as e:
        # Best-effort lookup: any failure degrades to "no collection".
        ctx.logger.error(f"❌ Collection lookup failed: {e}")
        return None