"""LangChain xAI Integration Service
|
|
|
|
Service für LangChain ChatXAI Integration mit File Search Binding.
|
|
Analog zu xai_service.py für xAI Files API.
|
|
"""
|
|
import os
|
|
from typing import Dict, List, Any, Optional, AsyncIterator
|
|
from services.logging_utils import get_service_logger
|
|
|
|
|
|
class LangChainXAIService:
    """
    Wrapper around LangChain's ChatXAI with Motia integration.

    Required environment variables:
        - XAI_API_KEY: API key for xAI (used by the ChatXAI model)

    Usage:
        service = LangChainXAIService(ctx)
        model = service.get_chat_model(model="grok-4-1-fast-reasoning")
        model_with_tools = service.bind_file_search(model, collection_id)
        result = await service.invoke_chat(model_with_tools, messages)
    """

    def __init__(self, ctx=None):
        """
        Initialize the LangChain xAI service.

        Args:
            ctx: Optional Motia context for logging

        Raises:
            ValueError: If XAI_API_KEY is not configured
        """
        self.api_key = os.getenv('XAI_API_KEY', '')
        self.ctx = ctx
        self.logger = get_service_logger('langchain_xai', ctx)

        if not self.api_key:
            raise ValueError("XAI_API_KEY not configured in environment")

    def _log(self, msg: str, level: str = 'info') -> None:
        """Delegate logging to the service logger; unknown levels fall back to info."""
        log_func = getattr(self.logger, level, self.logger.info)
        log_func(msg)

    def get_chat_model(
        self,
        model: str = "grok-4-1-fast-reasoning",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ):
        """
        Initialize a ChatXAI model.

        Args:
            model: Model name (default: grok-4-1-fast-reasoning)
            temperature: Sampling temperature 0.0-1.0
            max_tokens: Optional max tokens for the response

        Returns:
            ChatXAI model instance

        Raises:
            ImportError: If langchain_xai is not installed
        """
        try:
            from langchain_xai import ChatXAI
        except ImportError as err:
            # Chain the cause so the original import failure stays visible.
            raise ImportError(
                "langchain_xai not installed. "
                "Run: pip install langchain-xai>=0.2.0"
            ) from err

        self._log(f"🤖 Initializing ChatXAI: model={model}, temp={temperature}")

        kwargs = {
            "model": model,
            "api_key": self.api_key,
            "temperature": temperature
        }
        # Explicit None check: a truthiness test (`if max_tokens:`) would
        # silently drop a caller-provided 0.
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens

        return ChatXAI(**kwargs)

    def bind_tools(
        self,
        model,
        collection_id: Optional[str] = None,
        enable_web_search: bool = False,
        web_search_config: Optional[Dict[str, Any]] = None,
        max_num_results: int = 10
    ):
        """
        Bind xAI tools (file_search and/or web_search) to a model.

        Args:
            model: ChatXAI model instance
            collection_id: Optional xAI collection ID for file_search
            enable_web_search: Enable the web search tool (default: False)
            web_search_config: Optional web search configuration:
                {
                    'allowed_domains': ['example.com'],   # max 5 domains
                    'excluded_domains': ['spam.com'],     # max 5 domains
                    'enable_image_understanding': True
                }
                Note: only one of allowed/excluded is applied; if both are
                present, allowed_domains wins — presumably because the xAI
                API accepts only one filter kind (TODO confirm).
            max_num_results: Max results from file search (default: 10)

        Returns:
            Model with the requested tools bound, or the unmodified model
            when no tools were requested.
        """
        tools = []

        # Add file_search tool if collection_id provided
        if collection_id:
            self._log(f"🔍 Binding file_search: collection={collection_id}")
            tools.append({
                "type": "file_search",
                "vector_store_ids": [collection_id],
                "max_num_results": max_num_results
            })

        # Add web_search tool if enabled
        if enable_web_search:
            self._log("🌐 Binding web_search")
            web_search_tool = {"type": "web_search"}

            # Add optional web search filters
            if web_search_config:
                if 'allowed_domains' in web_search_config:
                    domains = web_search_config['allowed_domains'][:5]  # Max 5
                    web_search_tool['filters'] = {'allowed_domains': domains}
                    self._log(f"   Allowed domains: {domains}")
                elif 'excluded_domains' in web_search_config:
                    domains = web_search_config['excluded_domains'][:5]  # Max 5
                    web_search_tool['filters'] = {'excluded_domains': domains}
                    self._log(f"   Excluded domains: {domains}")

                if web_search_config.get('enable_image_understanding'):
                    web_search_tool['enable_image_understanding'] = True
                    self._log("   Image understanding: enabled")

            tools.append(web_search_tool)

        if not tools:
            # 'warning' is the standard logging method name; the previous
            # 'warn' silently fell back to info via _log's getattr default.
            self._log("⚠️ No tools to bind (no collection_id and web_search disabled)", level='warning')
            return model

        self._log(f"🔧 Binding {len(tools)} tool(s) to model")
        return model.bind_tools(tools)

    def bind_file_search(
        self,
        model,
        collection_id: str,
        max_num_results: int = 10
    ):
        """
        Legacy method: bind only the file_search tool to a model.

        Use bind_tools() for more flexibility.
        """
        return self.bind_tools(
            model=model,
            collection_id=collection_id,
            max_num_results=max_num_results
        )

    async def invoke_chat(
        self,
        model,
        messages: List[Dict[str, Any]]
    ) -> Any:
        """
        Non-streaming chat completion.

        Args:
            model: ChatXAI model (with or without tools)
            messages: List of message dicts [{"role": "user", "content": "..."}]

        Returns:
            LangChain AIMessage with the response

        Raises:
            Exception: If the API call fails
        """
        self._log(f"💬 Invoking chat: {len(messages)} messages", level='debug')

        result = await model.ainvoke(messages)

        # content may be a str or (for multimodal/tool responses) a list;
        # len() handles both, but guard against None on tool-call-only replies.
        content_len = len(result.content) if result.content is not None else 0
        self._log(f"✅ Response received: {content_len} chars", level='debug')
        return result

    async def astream_chat(
        self,
        model,
        messages: List[Dict[str, Any]]
    ) -> AsyncIterator:
        """
        Streaming chat completion.

        Args:
            model: ChatXAI model (with or without tools)
            messages: List of message dicts

        Yields:
            Chunks from the streaming response

        Example:
            async for chunk in service.astream_chat(model, messages):
                delta = chunk.content if hasattr(chunk, "content") else ""
                # Process delta...
        """
        self._log(f"💬 Streaming chat: {len(messages)} messages", level='debug')

        async for chunk in model.astream(messages):
            yield chunk