Six Months of LangGraph in Production: My Guide to Avoiding the Pitfalls
We started rebuilding our AI system on LangGraph in March of this year, so it has been almost six months now. Along the way we hit a number of pitfalls, some of which the official docs never mention. Here are the lessons learned.
TL;DR
LangGraph is probably a good fit if your system meets any of the following conditions:
- Complex multi-step decision flows
- Explicit state-management requirements
- Human review at critical nodes
- Multi-agent collaboration
But if all you need is simple single-turn dialogue or plain RAG, LangChain is enough. Don't make trouble for yourself.
State Management Pitfalls
1. Checkpointer Choice Is Life or Death
Early on, a teammate used InMemorySaver for testing and everything looked fine. Then we went live, the service restarted, and every conversation history was gone.
```python
# ❌ Never do this in production
from langgraph.checkpoint.memory import InMemorySaver

checkpointer = InMemorySaver()  # gone the moment the service restarts
```
```python
# ✅ The right way in production
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool

# Use a connection pool; don't open a new connection every time
async def create_checkpointer() -> AsyncPostgresSaver:
    pool = AsyncConnectionPool(
        "postgresql://user:pass@localhost/db",
        min_size=10,
        max_size=100,
        max_idle=300.0,       # max idle time per connection
        max_lifetime=3600.0,  # max lifetime per connection
        open=False,
    )
    await pool.open()
    # AsyncPostgresSaver accepts the pool directly. (The earlier
    # `async with pool.connection()` version handed back a saver whose
    # connection had already been returned to the pool.)
    checkpointer = AsyncPostgresSaver(pool)
    await checkpointer.setup()  # creates the checkpoint tables on first run
    return checkpointer
```
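To make the wiring concrete, here is a minimal sketch of attaching the checkpointer to a graph; `State` and `chat_node` are placeholders, not code from our system:

```python
from langgraph.graph import StateGraph, START, END

builder = StateGraph(State)          # State is your TypedDict schema
builder.add_node("chat", chat_node)  # chat_node is a placeholder node
builder.add_edge(START, "chat")
builder.add_edge("chat", END)

# (inside an async function)
checkpointer = await create_checkpointer()
graph = builder.compile(checkpointer=checkpointer)

# Runs sharing a thread_id now resume from Postgres across restarts
result = await graph.ainvoke(
    {"messages": [("user", "hello")]},
    {"configurable": {"thread_id": "thread-123"}},
)
```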
2. Thread ID Management
At first we used the user ID as the thread_id, and the moment one user opened several conversations at once their state got tangled together. We switched to plain UUIDs, then found we could no longer trace a user's history.
Our final approach: a composite ID strategy.
```python
import hashlib
from datetime import datetime

class ThreadManager:
    @staticmethod
    def generate_thread_id(user_id: str, session_type: str = "default"):
        """Generate a traceable thread_id."""
        # Format: user-id_session-type_timestamp_short-hash
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        unique_str = f"{user_id}_{session_type}_{timestamp}"
        short_hash = hashlib.md5(unique_str.encode()).hexdigest()[:8]
        return f"{user_id}_{session_type}_{timestamp}_{short_hash}"

    @staticmethod
    def parse_thread_id(thread_id: str):
        """Parse a thread_id back into its metadata."""
        # Assumes user_id itself contains no underscores
        parts = thread_id.split("_")
        return {
            "user_id": parts[0],
            "session_type": parts[1],
            "timestamp": parts[2],
            "hash": parts[3]
        }
```
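Hooked up to LangGraph, the generated ID simply rides along in the run config; a quick usage sketch (the `graph` object is assumed to be compiled with a checkpointer as above):

```python
thread_id = ThreadManager.generate_thread_id("user42", "support")

config = {"configurable": {"thread_id": thread_id}}
result = await graph.ainvoke({"messages": [("user", "hi")]}, config)

# Later, recover who/what/when directly from the ID
meta = ThreadManager.parse_thread_id(thread_id)
print(meta["user_id"], meta["session_type"], meta["timestamp"])
```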
3. State Version Control
LangGraph versions every channel value it stores, so each new checkpoint only persists the values that actually changed. But if your state schema changes often, you will run into compatibility problems.
```python
from typing import TypedDict

# Versioned state definitions
class StateV1(TypedDict):
    messages: list
    context: dict
    version: int  # always carry a version number

class StateV2(TypedDict):
    messages: list
    context: dict
    metadata: dict  # new field in V2
    version: int

class StateMigrator:
    """State migrator."""

    @staticmethod
    def migrate(state: dict) -> dict:
        version = state.get("version", 1)
        if version == 1:
            # V1 -> V2 migration
            state["metadata"] = {}
            state["version"] = 2
        # Append further migrations here as the schema evolves
        return state

    @staticmethod
    def load_state(thread_id: str, checkpointer):
        """Load a state and migrate it on the fly."""
        state = checkpointer.get({"configurable": {"thread_id": thread_id}})
        if state:
            return StateMigrator.migrate(state)
        return None
```

Error Handling
1. Node-Level Retries
LangGraph provides retry_policy for retrying failed nodes, and only the failed branch is retried, so you don't have to worry about completed work being redone. But the default retry policy is too simplistic.
```python
from langgraph.types import RetryPolicy
import httpx

class SmartRetryPolicy:
    """Smarter retry policies, keyed by node type."""

    @staticmethod
    def create_policy(node_name: str) -> RetryPolicy:
        # Different node types warrant different retry behavior
        if "llm" in node_name:
            return RetryPolicy(
                max_attempts=3,
                backoff_factor=2.0,
                max_interval=30.0,
                retry_on=lambda e: SmartRetryPolicy.should_retry_llm(e)
            )
        elif "api" in node_name:
            return RetryPolicy(
                max_attempts=5,
                backoff_factor=1.5,
                max_interval=60.0,
                retry_on=lambda e: SmartRetryPolicy.should_retry_api(e)
            )
        else:
            # Default policy
            return RetryPolicy(max_attempts=2)

    @staticmethod
    def should_retry_llm(error: Exception) -> bool:
        """Should an LLM call be retried?"""
        # Rate limits and upstream hiccups must be retried
        if isinstance(error, httpx.HTTPStatusError):
            return error.response.status_code in [429, 502, 503, 504]
        # Network errors: retry
        if isinstance(error, (httpx.ConnectError, httpx.TimeoutException)):
            return True
        # Invalid-argument errors: don't retry
        if "invalid" in str(error).lower():
            return False
        return True

    @staticmethod
    def should_retry_api(error: Exception) -> bool:
        """Should a plain API call be retried?"""
        if isinstance(error, httpx.HTTPStatusError):
            # Retry every 5xx, plus 429 rate limits
            return error.response.status_code >= 500 or error.response.status_code == 429
        return isinstance(error, (httpx.ConnectError, httpx.TimeoutException))

# Usage: attach the policy when adding the node
# (some older LangGraph releases spell this keyword `retry=` instead)
builder = StateGraph(State)
builder.add_node(
    "llm_node",
    process_llm,  # your node function
    retry_policy=SmartRetryPolicy.create_policy("llm_node")
)
```
2. Global Error Recovery
Node retries won't solve everything; you also need a global recovery mechanism:
```python
from langgraph.errors import GraphRecursionError
import asyncio

class ResilientGraphRunner:
    def __init__(self, graph, checkpointer):
        self.graph = graph
        self.checkpointer = checkpointer
        self.dead_letter_queue = []  # dead-letter queue

    async def run_with_recovery(
        self,
        input_data: dict,
        thread_id: str,
        max_recovery_attempts: int = 3
    ):
        """Run the graph with a recovery loop around it."""
        attempt = 0
        last_error = None
        while attempt < max_recovery_attempts:
            try:
                # Note: recursion_limit lives at the top level of the
                # config, not inside "configurable"
                config = {
                    "configurable": {"thread_id": thread_id},
                    "recursion_limit": 100,  # guard against infinite loops
                }
                # Try the run
                result = await self.graph.ainvoke(input_data, config)
                return result
            except GraphRecursionError as e:
                # Recursion depth exceeded -- probably a cycle
                await self.handle_recursion_error(thread_id, e)
                break
            except Exception as e:
                last_error = e
                attempt += 1
                # Record the error
                await self.log_error(thread_id, e, attempt)
                # Try to recover from the last successful checkpoint
                if attempt < max_recovery_attempts:
                    await self.recover_from_checkpoint(thread_id)
                    await asyncio.sleep(2 ** attempt)  # exponential backoff
        # All retries failed -> dead-letter queue
        await self.send_to_dead_letter(thread_id, input_data, last_error)
        raise last_error

    async def recover_from_checkpoint(self, thread_id: str):
        """Roll back to the last successful checkpoint."""
        # Fetch recent checkpoints for this thread
        checkpoints = self.checkpointer.list(
            {"configurable": {"thread_id": thread_id}},
            limit=10
        )
        for checkpoint in checkpoints:
            if checkpoint.metadata.get("status") == "success":
                # Restore that state (newer checkpointer APIs also take
                # a new_versions argument here)
                self.checkpointer.put(
                    {"configurable": {"thread_id": thread_id}},
                    checkpoint.checkpoint,
                    checkpoint.metadata
                )
                break

    # handle_recursion_error / log_error / send_to_dead_letter are
    # logging and alerting plumbing, omitted here for brevity
```

3. Tool-Call Error Handling
Tool nodes now return ToolMessages carrying an error field when a tool call fails, but the default handling is too crude:
```python
from datetime import datetime
from langchain_core.messages import ToolMessage
from typing import List, Dict, Any

class SafeToolExecutor:
    """Tool executor with guard rails."""

    def __init__(self, tools: List, fallback_model=None):
        self.tools = {tool.name: tool for tool in tools}
        self.fallback_model = fallback_model
        self.execution_history = []  # execution log

    async def execute_with_fallback(
        self,
        tool_calls: List[Dict[str, Any]],
        state: Dict
    ) -> List[ToolMessage]:
        """Execute tool calls with a degradation path on failure."""
        results = []
        for tool_call in tool_calls:
            tool_name = tool_call.get("name")
            tool_args = tool_call.get("args", {})
            # Does the tool even exist?
            if tool_name not in self.tools:
                results.append(ToolMessage(
                    content=f"Tool {tool_name} not found",
                    tool_call_id=tool_call.get("id"),
                    additional_kwargs={"error": "ToolNotFound"}
                ))
                continue
            # Run the tool
            try:
                result = await self.execute_single_tool(
                    tool_name,
                    tool_args,
                    state
                )
                results.append(ToolMessage(
                    content=str(result),
                    tool_call_id=tool_call.get("id")
                ))
            except Exception as e:
                # Log the failure
                self.execution_history.append({
                    "tool": tool_name,
                    "args": tool_args,
                    "error": str(e),
                    "timestamp": datetime.now()
                })
                # Try the degradation strategies
                fallback_result = await self.try_fallback(
                    tool_name,
                    tool_args,
                    e,
                    state
                )
                results.append(ToolMessage(
                    content=fallback_result,
                    tool_call_id=tool_call.get("id"),
                    additional_kwargs={
                        "error": str(e),
                        "fallback_used": True
                    }
                ))
        return results

    async def execute_single_tool(self, tool_name: str, args: dict, state: dict):
        """Invoke one tool (timeouts etc. omitted)."""
        return await self.tools[tool_name].arun(**args)

    def get_backup_tool(self, tool_name: str):
        """Convention-based lookup of a registered backup tool."""
        return self.tools.get(f"{tool_name}_backup")

    async def try_fallback(
        self,
        tool_name: str,
        args: dict,
        error: Exception,
        state: dict
    ) -> str:
        """Degradation strategies, in order."""
        # Strategy 1: use a backup tool
        backup_tool = self.get_backup_tool(tool_name)
        if backup_tool:
            try:
                return await backup_tool.arun(**args)
            except Exception:
                pass
        # Strategy 2: let an LLM approximate the result
        if self.fallback_model:
            prompt = f"""
            Tool {tool_name} failed.
            Arguments: {args}
            Error: {error}
            Based on the current context, provide a reasonable substitute answer.
            Context: {state.get('context', '')}
            """
            response = await self.fallback_model.ainvoke(prompt)
            return str(response.content)
        # Strategy 3: return a meaningful error message
        return f"Tool execution failed, please try another approach: {error}"
```

Performance Optimization
1. Parallel Execution, Done Right
Many people don't realize LangGraph executes nodes in parallel automatically:
```python
from langgraph.graph import StateGraph, START, END
from typing import List

class OptimizedGraph:
    @staticmethod
    def build_parallel_graph():
        builder = StateGraph(State)
        # (add_node calls for the individual nodes omitted here)

        # These nodes will run in parallel automatically!
        def route_parallel(state) -> List[str]:
            """Return several node names; they all execute in parallel."""
            tasks = []
            if state.get("need_search"):
                tasks.append("search_node")
            if state.get("need_calculation"):
                tasks.append("calc_node")
            if state.get("need_validation"):
                tasks.append("validate_node")
            return tasks if tasks else ["default_node"]

        # Conditional edges are how you fan out
        builder.add_conditional_edges(
            START,
            route_parallel,
            # any of these can run in parallel
            ["search_node", "calc_node", "validate_node", "default_node"]
        )
        # Fan-in: aggregate once all parallel branches finish
        builder.add_edge(["search_node", "calc_node", "validate_node"], "aggregate_node")
        return builder.compile()
```
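When the fan-out is dynamic rather than a fixed set of node names (say, one branch per retrieved document, each with its own payload), LangGraph's Send API covers what a plain list of node names cannot. A minimal sketch, where `plan_node`, `process_item`, and the `documents` field are placeholders:

```python
from langgraph.types import Send

def fan_out(state):
    # One parallel "process_item" run per document, each with its own input
    return [Send("process_item", {"item": doc}) for doc in state["documents"]]

builder.add_conditional_edges("plan_node", fan_out, ["process_item"])
```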
2. Node Caching
LangGraph now supports node-level caching: individual node results can be cached, cutting repeated computation and speeding up execution. On top of that, we run our own two-tier cache:
```python
import hashlib
import pickle
from typing import List

class NodeCache:
    """Node-level cache: an in-process dict in front of Redis."""

    def __init__(self, redis_client=None):
        self.redis = redis_client
        self.local_cache = {}  # in-memory L1 cache

    def cache_key(self, node_name: str, state: dict) -> str:
        """Build a cache key."""
        # Only hash the fields this node cares about; ignore the rest
        relevant_fields = self.get_relevant_fields(node_name)
        cache_data = {k: state.get(k) for k in relevant_fields}
        # Stable hash of the relevant slice of state
        data_str = pickle.dumps(cache_data, protocol=pickle.HIGHEST_PROTOCOL)
        return f"node:{node_name}:{hashlib.md5(data_str).hexdigest()}"

    def get_relevant_fields(self, node_name: str) -> List[str]:
        """Which state fields matter for each node."""
        # Different nodes care about different fields
        field_map = {
            "search_node": ["query", "filters"],
            "llm_node": ["messages", "temperature"],
            "calc_node": ["formula", "variables"]
        }
        return field_map.get(node_name, ["messages"])

    async def get_or_compute(
        self,
        node_name: str,
        state: dict,
        compute_func,
        ttl: int = 3600
    ):
        """Return a cached result or compute and cache it."""
        cache_key = self.cache_key(node_name, state)
        # L1: process memory
        if cache_key in self.local_cache:
            return self.local_cache[cache_key]
        # L2: Redis
        if self.redis:
            cached = await self.redis.get(cache_key)
            if cached:
                result = pickle.loads(cached)
                self.local_cache[cache_key] = result
                return result
        # Miss: compute
        result = await compute_func(state)
        # Write back to both tiers
        self.local_cache[cache_key] = result
        if self.redis:
            # redis.asyncio spells the TTL argument `ex`
            await self.redis.set(cache_key, pickle.dumps(result), ex=ttl)
        return result

# Cache decorator for node functions
# (in practice the NodeCache instance is bound via a closure or DI,
# not passed in by LangGraph itself)
def cached_node(ttl=3600):
    def decorator(func):
        async def wrapper(state: dict, cache: NodeCache):
            return await cache.get_or_compute(
                func.__name__,
                state,
                func,
                ttl
            )
        return wrapper
    return decorator

@cached_node(ttl=7200)
async def expensive_search_node(state: dict):
    """Expensive search whose result gets cached."""
    # perform_search is your actual search implementation
    results = await perform_search(state["query"])
    return {"search_results": results}
```
3. Streaming Output Optimization
The stream_mode choice mentioned earlier matters a lot, but there are more optimization points:
```python
class StreamOptimizer:
    """Streaming output optimizer."""

    @staticmethod
    async def optimized_stream(graph, input_data, config):
        """Optimized streaming loop."""
        # stream_mode="updates" keeps the payload small
        async for chunk in graph.astream(
            input_data,
            config,
            stream_mode="updates"
        ):
            # Only forward the updates that matter
            for node_name, updates in chunk.items():
                # Drop internal state updates
                filtered_updates = StreamOptimizer.filter_updates(updates)
                if filtered_updates:
                    # Compress anything big
                    compressed = StreamOptimizer.compress_if_needed(filtered_updates)
                    yield node_name, compressed

    @staticmethod
    def filter_updates(updates: dict) -> dict:
        """Strip fields the client doesn't need."""
        internal_fields = [
            "_raw_response",
            "_checkpoint_data",
            "_debug_info",
            "tool_calls"  # clients usually don't need the raw tool calls
        ]
        return {
            k: v for k, v in updates.items()
            if k not in internal_fields and not k.startswith("_")
        }

    @staticmethod
    def compress_if_needed(data: dict) -> dict:
        """Compress large values."""
        import sys
        import gzip
        import base64
        for key, value in data.items():
            # gzip strings over 10KB
            if isinstance(value, str) and sys.getsizeof(value) > 10240:
                compressed = gzip.compress(value.encode())
                data[key] = {
                    "compressed": True,
                    "data": base64.b64encode(compressed).decode()
                }
            # Big lists become a summary
            elif isinstance(value, list) and len(value) > 100:
                data[key] = {
                    "summary": f"List with {len(value)} items",
                    "preview": value[:10],  # first 10 only
                    "total": len(value)
                }
        return data
```
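To show where this sits in a service, here is a hedged sketch of exposing optimized_stream over server-sent events with FastAPI; the route shape and payload framing are illustrative, not our production interface:

```python
import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/chat/{thread_id}")
async def chat(thread_id: str, payload: dict):
    config = {"configurable": {"thread_id": thread_id}}

    async def event_stream():
        async for node_name, updates in StreamOptimizer.optimized_stream(
            graph, payload, config
        ):
            # One SSE event per filtered node update
            yield f"data: {json.dumps({'node': node_name, 'updates': updates}, default=str)}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")
```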
Production Deployment Details
1. Multi-Environment Configuration Management
```python
from enum import Enum
from typing import Optional
from pydantic import BaseSettings  # pydantic v2: from pydantic_settings import BaseSettings

class Environment(Enum):
    DEV = "dev"
    STAGING = "staging"
    PROD = "prod"

class LangGraphConfig(BaseSettings):
    """Configuration management."""
    environment: Environment
    # Database
    postgres_url: str
    postgres_pool_size: int = 20
    # Redis
    redis_url: str
    redis_pool_size: int = 50
    # LLM
    openai_api_key: str
    openai_timeout: int = 30
    openai_max_retries: int = 3
    # Graph
    max_recursion_depth: int = 100
    default_thread_ttl: int = 86400  # 24 hours
    # Monitoring
    enable_tracing: bool = True
    langsmith_api_key: Optional[str] = None

    class Config:
        # pick the env file per deployment; pinned to prod here
        # only as an example
        env_file = f".env.{Environment.PROD.value}"

    def get_checkpointer_config(self):
        """Checkpointer settings per environment."""
        if self.environment == Environment.DEV:
            # In-memory for development
            return {"type": "memory"}
        elif self.environment == Environment.STAGING:
            # SQLite for staging
            return {
                "type": "sqlite",
                "path": "checkpoints.db"
            }
        else:
            # PostgreSQL in production
            return {
                "type": "postgres",
                "url": self.postgres_url,
                "pool_size": self.postgres_pool_size
            }
```

2. Monitoring and Observability
```python
from dataclasses import dataclass
from datetime import datetime
from typing import List
import json

@dataclass
class GraphMetrics:
    """Metrics for one graph run."""
    thread_id: str
    start_time: datetime
    end_time: datetime
    total_nodes_executed: int
    failed_nodes: List[str]
    retry_count: int
    total_tokens: int
    total_cost: float

class MetricsCollector:
    """Metrics collector."""

    def __init__(self, prometheus_client=None):
        self.prometheus = prometheus_client
        self.metrics_buffer = []

    async def track_node_execution(self, node_name: str, duration: float, success: bool):
        """Track a single node execution."""
        if self.prometheus:
            self.prometheus.histogram(
                "langgraph_node_duration",
                duration,
                labels={"node": node_name, "success": str(success)}
            )
            self.prometheus.increment(
                "langgraph_node_executions",
                labels={"node": node_name, "status": "success" if success else "failure"}
            )

    async def track_graph_execution(self, metrics: GraphMetrics):
        """Track a whole graph run."""
        # Ship to the monitoring system
        if self.prometheus:
            duration = (metrics.end_time - metrics.start_time).total_seconds()
            self.prometheus.histogram(
                "langgraph_graph_duration",
                duration
            )
            self.prometheus.gauge(
                "langgraph_graph_cost",
                metrics.total_cost
            )
        # Buffer detailed records for later analysis
        self.metrics_buffer.append(metrics)
        # Flush in batches
        if len(self.metrics_buffer) >= 100:
            await self.flush_metrics()

    async def flush_metrics(self):
        """Write buffered metrics in one batch."""
        if not self.metrics_buffer:
            return
        batch_data = [
            json.dumps(m.__dict__, default=str)
            for m in self.metrics_buffer
        ]
        # write_to_datawarehouse is your sink (warehouse, log pipeline, ...)
        await write_to_datawarehouse(batch_data)
        self.metrics_buffer.clear()
```

3. Load Balancing and Scaling
```python
class GraphPoolManager:
    """Pool of compiled graph instances."""

    def __init__(self, min_instances=2, max_instances=10):
        self.min_instances = min_instances
        self.max_instances = max_instances
        self.instances = []
        self.current_index = 0

    async def get_instance(self):
        """Round-robin over the pool."""
        if not self.instances:
            await self.initialize_pool()
        # Simple round-robin
        instance = self.instances[self.current_index]
        self.current_index = (self.current_index + 1) % len(self.instances)
        return instance

    async def scale_based_on_load(self, current_qps: float):
        """Scale the pool up or down with load."""
        target_instances = self.calculate_target_instances(current_qps)
        current_count = len(self.instances)
        if target_instances > current_count:
            # Scale up
            for _ in range(target_instances - current_count):
                self.instances.append(await self.create_instance())
        elif target_instances < current_count:
            # Scale down
            excess = current_count - target_instances
            for _ in range(excess):
                instance = self.instances.pop()
                await self.destroy_instance(instance)

    def calculate_target_instances(self, qps: float) -> int:
        """How many instances do we need?"""
        # Assume each instance handles ~100 QPS
        target = int(qps / 100) + 1
        return max(self.min_instances, min(target, self.max_instances))

    # initialize_pool / create_instance / destroy_instance build and
    # tear down compiled graphs; omitted for brevity
```

Pitfall Summary
Things You Must Remember
- A checkpointer is mandatory: when writing a custom checkpoint saver, implement an async version so you don't block the main thread.
- Set the recursion limit deliberately: the default may not be enough, but setting it too high makes runaway loops hard to spot.
- Handle tool errors gracefully: tool-call failures are the norm, not the exception.
- Make state updates atomic: parallel nodes writing the same field is a race condition waiting to happen; see the reducer sketch below.
- Monitor from day one: don't wait for an incident before adding monitoring.
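On the atomic-updates point, the LangGraph-native fix is to declare a reducer on every channel that parallel nodes write to, so concurrent updates get merged instead of clobbering each other. A minimal sketch:

```python
import operator
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages

class State(TypedDict):
    # add_messages merges message lists by ID instead of overwriting
    messages: Annotated[list, add_messages]
    # operator.add concatenates lists returned by parallel branches
    search_results: Annotated[list, operator.add]
```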
When You Shouldn't Use LangGraph
- Simple Q&A systems: plain LangChain will do
- Pure streaming generation: the model's streaming API is enough
- Stateless API calls: FastAPI is a better fit
- Extremely latency-sensitive scenarios: graph traversal has overhead
Wrap-Up
LangGraph is genuinely powerful, but it is not a silver bullet. In the right place it can transform your system; in the wrong place it's just over-engineering.
The most important lesson: start simple and add complexity gradually. Don't open with a 20-node graph; start with 3-5 nodes, get those running stably, then add features.
Also, LangGraph's graph-based architecture really does offer a huge range of flexibility, from fully open-ended agents to fully deterministic pipelines.
If you're running LangGraph too, I'd love to trade war stories. Only people who have actually run it in production know how deep these pits go.
All the code in this post is simplified from production code, so copy-pasting it will likely need adjustment. The point is to understand the ideas, not to copy the code.
Reprinted from AI 博物院; author: longyunfeigu.