
sidebar.wechat

sidebar.feishu
sidebar.chooseYourWayToJoin

sidebar.scanToAddConsultant
AI 生成的 SQL 可能因为表名错误、字段不存在、语法错误等原因执行失败。如何让 Agent 从错误中学习，自动纠正并重试？
AskTable 的 Agent 自我纠错机制，通过“错误检测 + Prompt 调整 + Case 收集”的闭环，实现了持续优化。
表名错误:
-- AI-generated
-- NOTE(review): ORDER is a reserved word in MySQL, so an unquoted `order`
-- would likely fail with a syntax error rather than the message shown below — confirm.
SELECT * FROM order -- wrong: the table should be orders
-- Error message
Table 'order' doesn't exist
字段不存在:
-- AI-generated
SELECT user_name FROM users -- wrong: the column should be username
-- Error message
Unknown column 'user_name' in 'field list'
语法错误:
-- AI-generated
SELECT * FROM orders WHERE -- wrong: the condition after WHERE is missing
-- Error message
You have an error in your SQL syntax
直接返回错误:
无限重试:
class SQLError(Enum):
    """Coarse categories used to label a failed SQL execution.

    Each member's value is the lower-case snake_case form of its name.
    """

    # Failures about objects referenced by the statement.
    TABLE_NOT_FOUND = "table_not_found"
    COLUMN_NOT_FOUND = "column_not_found"

    # The statement itself is malformed.
    SYNTAX_ERROR = "syntax_error"

    # Failures unrelated to the statement text.
    PERMISSION_DENIED = "permission_denied"
    TIMEOUT = "timeout"

    # Fallback when no other category applies.
    UNKNOWN = "unknown"
def classify_error(error_message: str) -> SQLError:
    """Map a database error message onto a coarse ``SQLError`` category.

    Matching is a case-insensitive substring search over an ordered rule
    table; the first rule whose needle occurs in the message wins.

    Args:
        error_message: Text of the error, typically ``str(exception)``.
            ``None`` or an empty string is treated as an unrecognised message.

    Returns:
        The matching ``SQLError`` member, or ``SQLError.UNKNOWN`` when no
        rule applies.
    """
    # Normalise once instead of lower-casing the message in every branch.
    text = (error_message or "").lower()

    # The first five rules keep the original precedence exactly. The trailing
    # synonyms are only reached by messages that match none of the rules
    # above them, so every previously classified message keeps its category.
    rules = (
        ("doesn't exist", SQLError.TABLE_NOT_FOUND),
        ("unknown column", SQLError.COLUMN_NOT_FOUND),
        ("syntax", SQLError.SYNTAX_ERROR),
        ("permission denied", SQLError.PERMISSION_DENIED),
        ("timeout", SQLError.TIMEOUT),
        # MySQL-style wording, e.g. "Access denied for user ..." and
        # "SELECT command denied to user ...".
        ("access denied", SQLError.PERMISSION_DENIED),
        ("command denied", SQLError.PERMISSION_DENIED),
        # "... timed out" does not contain the substring "timeout".
        ("timed out", SQLError.TIMEOUT),
    )
    for needle, category in rules:
        if needle in text:
            return category
    return SQLError.UNKNOWN
async def execute_with_retry(
    question: str,
    datasource: DataSourceAdmin,
    max_retries: int = 3,
) -> QueryResult:
    """Generate SQL for *question* and execute it, retrying on failure.

    Every failed attempt is appended to an error history that is handed back
    to ``generate_sql`` on the next attempt. A successful query is stored via
    ``save_good_case``; the failing statement of the final attempt is stored
    via ``save_bad_case``.

    Args:
        question: The natural-language question to answer.
        datasource: Data source used for SQL generation and execution.
        max_retries: Total number of attempts (the loop runs
            ``range(max_retries)``); must be at least 1.

    Returns:
        ``QueryResult`` built from the successful SQL and its result.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1. Previously the loop
            was skipped and ``None`` was returned silently.
        Exception: The error raised by the final attempt is re-raised
            unchanged. An error from ``save_good_case`` propagates directly
            and does not trigger a retry.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    error_history: list[dict] = []
    for attempt in range(max_retries):
        # Reset per attempt: if generation itself fails, the handler must not
        # see an unbound name (first attempt) or the previous attempt's SQL.
        sql: str | None = None
        try:
            # 1. Generate SQL, feeding back the errors seen so far.
            sql = await generate_sql(
                question,
                datasource,
                error_history=error_history,
            )
            # 2. Execute it.
            df = await datasource.execute_sql(sql)
        except Exception as e:
            # 3. Record the failure; "sql" is None when generation failed
            #    before any statement was produced.
            error_type = classify_error(str(e))
            error_history.append({
                "sql": sql,
                "error": str(e),
                "error_type": error_type,
                "attempt": attempt + 1,
            })
            # 4. Out of attempts: keep the bad case and surface the error.
            if attempt == max_retries - 1:
                # Only a statement that was actually generated is a bad case.
                if sql is not None:
                    await save_bad_case(question, sql, str(e))
                raise
            # 5. Otherwise log and try again.
            log.warning(f"SQL execution failed (attempt {attempt + 1}): {e}")
        else:
            # 6. Success. Saving happens outside the ``try`` so a bookkeeping
            #    failure is not recorded as a SQL error and does not cause a
            #    correct query to be regenerated.
            await save_good_case(question, sql, df)
            return QueryResult(sql=sql, dataframe=df)
async def generate_sql(
    question: str,
    datasource: DataSourceAdmin,
    error_history: list[dict] | None = None,
) -> str:
    """Ask the LLM for a SQL statement answering *question*.

    The prompt starts with the data source rendered by
    ``datasource.to_markdown()`` and the question. When ``error_history`` is
    non-empty, a section is appended that lists each earlier attempt with its
    SQL, error text and error type, each followed by a request to fix the SQL.

    Returns:
        The ``"sql"`` entry of the response returned by ``llm.generate``.
    """
    # Base prompt: schema description plus the question.
    sections = [
        "\n"
        f"数据库:{datasource.to_markdown()}\n"
        f"问题:{question}\n"
        "请生成 SQL 查询。\n"
    ]

    # Correction hints, one block per earlier failed attempt.
    if error_history:
        sections.append("\n\n## 错误历史\n")
        sections.extend(
            "\n"
            f"尝试 {entry['attempt']}:\n"
            f"SQL: {entry['sql']}\n"
            f"错误: {entry['error']}\n"
            f"错误类型: {entry['error_type']}\n"
            "请根据错误信息修正 SQL。\n"
            for entry in error_history
        )

    # Call the LLM with the assembled prompt.
    reply = await llm.generate("".join(sections))
    return reply["sql"]
async def save_good_case(
    question: str,
    sql: str,
    dataframe: pd.DataFrame,
):
    """Store a successful question/SQL pair in ``training_pairs`` as 'good'.

    Args:
        question: The question that was asked.
        sql: The SQL statement that executed successfully.
        dataframe: Accepted for the caller's convenience; not read here.
    """
    insert_stmt = (
        "\n"
        "INSERT INTO training_pairs (question, sql, status, created_at)\n"
        "VALUES (:question, :sql, 'good', NOW())\n"
    )
    # Values are passed as named bind parameters, not spliced into the SQL.
    bind_values = {"question": question, "sql": sql}
    await db.execute(insert_stmt, bind_values)
async def save_bad_case(
    question: str,
    sql: str,
    error: str,
):
    """Store a failed question/SQL pair in ``training_pairs`` as 'bad'.

    Args:
        question: The question that was asked.
        sql: The SQL statement that failed.
        error: The error text recorded alongside the statement.
    """
    insert_stmt = (
        "\n"
        "INSERT INTO training_pairs (question, sql, error, status, created_at)\n"
        "VALUES (:question, :sql, :error, 'bad', NOW())\n"
    )
    # Values are passed as named bind parameters, not spliced into the SQL.
    bind_values = {"question": question, "sql": sql, "error": error}
    await db.execute(insert_stmt, bind_values)
# Attempt 1
# NOTE(review): ORDER is a reserved word in MySQL, so this statement would
# likely produce a syntax error rather than the message below — confirm.
sql_1 = "SELECT * FROM order"
error_1 = "Table 'order' doesn't exist"
# Attempt 2 (adjusted according to the error)
sql_2 = "SELECT * FROM orders" # succeeds
# Attempt 1
sql_1 = "SELECT user_name FROM users"
error_1 = "Unknown column 'user_name'"
# Attempt 2
sql_2 = "SELECT username FROM users" # succeeds
max_retries = 3 # at most 3 attempts in total: execute_with_retry loops over range(max_retries)
效果:
error_cache: dict[str, str] = {}
def get_cached_fix(sql: str, error: str) -> str | None:
"""获取缓存的修复方案"""
cache_key = f"{sql}:{error}"
return error_cache.get(cache_key)
效果:
AskTable Agent 的自我纠错机制，通过“错误检测 + 带 Prompt 调整的重试 + Case 收集”的闭环，实现了：
- ✅ 自动修复：常见错误自动纠正
- ✅ 持续优化：从错误中学习
- ✅ 用户体验：减少手动干预
- ✅ 成本控制：限制重试次数
相关阅读:
技术交流:
sidebar.noProgrammingNeeded
sidebar.startFreeTrial