
企业微信

飞书
选择您喜欢的方式加入群聊

扫码添加咨询专家
Text-to-SQL 技术让非技术人员也能用自然语言查询数据库,但在实际应用中,准确率是最大的挑战。一个 60% 准确率的系统几乎无法使用(10 次查询有 4 次错误),而 95% 准确率才能达到生产可用的标准。
本文将系统性地分享我们在 AskTable 项目中,如何将 Text-to-SQL 准确率从 60% 提升到 95% 的完整技术路径。
Text-to-SQL 的准确性有多个层次:
层次 1:语法正确性
-- 用户问:"本月销售额"
-- 生成的 SQL 能否成功执行?
✅ 正确:SELECT SUM(amount) FROM orders WHERE ...
❌ 错误:SELECT SUM(amount FROM orders WHERE ... -- 语法错误
层次 2:语义正确性
-- 用户问:"本月销售额"
-- 生成的 SQL 是否符合用户意图?
✅ 正确:WHERE created_at >= '2026-03-01' AND created_at < '2026-04-01'
❌ 错误:WHERE created_at >= '2026-02-01' AND created_at < '2026-03-01' -- 时间范围错误
层次 3:业务逻辑正确性
-- 用户问:"本月销售额"
-- 是否应用了正确的业务规则?
✅ 正确:WHERE status IN ('paid', 'completed') AND amount > 0
❌ 错误:没有过滤订单状态,包含了未支付和取消的订单
方法 1:执行成功率
执行成功率 = 成功执行的查询数 / 总查询数
局限性:只能判断语法正确性,无法判断语义是否正确。
方法 2:结果一致性
结果一致性 = 结果与标准答案一致的查询数 / 总查询数
实施方式:
方法 3:人工评估
人工准确率 = 人工判断为正确的查询数 / 总查询数
评估标准:
我们的测试方法:
技术栈:
Prompt 示例:
你是一个 SQL 专家。根据用户的自然语言问题,生成对应的 SQL 查询。
数据库 Schema:
- orders 表:order_id, user_id, amount, status, created_at
- users 表:user_id, name, email, region
示例:
问题:今天的订单数量
SQL:SELECT COUNT(*) FROM orders WHERE DATE(created_at) = CURDATE()
问题:{user_question}
SQL:
问题 1:表关系理解错误
用户问:"每个用户的订单总额"
错误 SQL:
SELECT user_id, SUM(amount)
FROM orders
GROUP BY user_id
问题:没有 JOIN users 表,无法显示用户名称
问题 2:时间处理不准确
用户问:"上个月的销售额"
错误 SQL:
SELECT SUM(amount)
FROM orders
WHERE created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH)
问题:"上个月"应该是完整的自然月,而不是"过去 30 天"
问题 3:业务规则缺失
用户问:"本月销售额"
错误 SQL:
SELECT SUM(amount)
FROM orders
WHERE MONTH(created_at) = MONTH(NOW())
问题:
- 没有过滤订单状态(包含了未支付订单)
- 没有排除测试订单
- 没有处理退款
问题 4:复杂查询失败
用户问:"对比今年和去年同期的销售额增长率"
错误:生成的 SQL 逻辑混乱,无法执行
| 查询类型 | 准确率 | 主要问题 |
|---|---|---|
| 简单查询(单表) | 85% | 时间处理、业务规则 |
| 多表关联 | 50% | JOIN 条件错误 |
| 聚合分析 | 70% | GROUP BY 遗漏 |
| 时间对比 | 40% | 复杂逻辑理解 |
| 总体 | 60% | - |
策略 1:增强 Schema 描述
之前:
orders 表:order_id, user_id, amount, status, created_at
改进后:
orders 表(订单表):
- order_id (int, 主键): 订单 ID
- user_id (int, 外键 -> users.user_id): 用户 ID
- amount (decimal): 订单金额(单位:元)
- status (varchar): 订单状态
可选值: 'pending'(待支付), 'paid'(已支付), 'cancelled'(已取消)
- created_at (datetime): 订单创建时间
- paid_at (datetime): 支付时间(可为 NULL)
业务规则:
- 统计销售额时,只计算 status='paid' 的订单
- 排除 user_id < 10000 的测试订单
策略 2:提供更多示例(Few-shot Learning)
之前:1-2 个简单示例
改进后:10-15 个覆盖不同场景的示例
示例 1(简单查询):
问题:今天的订单数量
SQL:SELECT COUNT(*) FROM orders WHERE DATE(created_at) = CURDATE()
示例 2(多表关联):
问题:每个用户的订单总额
SQL:
SELECT u.name, SUM(o.amount) as total
FROM users u
LEFT JOIN orders o ON u.user_id = o.user_id
WHERE o.status = 'paid'
GROUP BY u.user_id, u.name
示例 3(时间范围):
问题:上个月的销售额
SQL:
SELECT SUM(amount) as gmv
FROM orders
WHERE status = 'paid'
AND created_at >= DATE_FORMAT(DATE_SUB(NOW(), INTERVAL 1 MONTH), '%Y-%m-01')
AND created_at < DATE_FORMAT(NOW(), '%Y-%m-01')
示例 4(同比对比):
问题:对比今年和去年同期的销售额
SQL:
SELECT
YEAR(created_at) as year,
SUM(amount) as gmv
FROM orders
WHERE status = 'paid'
AND created_at >= DATE_SUB(NOW(), INTERVAL 2 YEAR)
GROUP BY YEAR(created_at)
策略 3:Chain-of-Thought(思维链)
引导模型先分析问题,再生成 SQL:
请按以下步骤分析:
1. 用户想查询什么指标?
2. 需要哪些表?
3. 表之间如何关联?
4. 需要什么过滤条件?
5. 需要分组或排序吗?
然后生成 SQL。
问题:本月各地区的销售额排名
分析:
1. 指标:销售额(SUM(amount))
2. 表:orders(订单表)、users(用户表,包含地区信息)
3. 关联:orders.user_id = users.user_id
4. 过滤:本月(created_at >= 本月1号)、已支付(status='paid')
5. 分组:按地区(users.region),排序:按销售额降序
SQL:
SELECT
u.region,
SUM(o.amount) as gmv
FROM orders o
JOIN users u ON o.user_id = u.user_id
WHERE o.status = 'paid'
AND o.created_at >= DATE_FORMAT(NOW(), '%Y-%m-01')
GROUP BY u.region
ORDER BY gmv DESC
| 查询类型 | 之前 | 改进后 | 提升 |
|---|---|---|---|
| 简单查询 | 85% | 92% | +7% |
| 多表关联 | 50% | 70% | +20% |
| 聚合分析 | 70% | 80% | +10% |
| 时间对比 | 40% | 65% | +25% |
| 总体 | 60% | 75% | +15% |
不要让 AI 每次都重新理解业务规则,而是将业务规则封装为可复用的组件。
指标定义:
指标:
- 名称: 销售额
英文: GMV
定义: 已支付订单的金额总和
SQL 模板: |
SELECT SUM(amount) as gmv
FROM orders
WHERE status = 'paid'
AND user_id >= 10000 -- 排除测试用户
AND amount > 0
{time_filter} -- 时间过滤占位符
{additional_filter} -- 额外过滤占位符
同义词: [营收, 交易额, 成交额]
- 名称: 订单量
英文: Order Count
SQL 模板: |
SELECT COUNT(*) as order_count
FROM orders
WHERE status = 'paid'
AND user_id >= 10000
{time_filter}
同义词: [订单数, 成交单数]
维度定义:
维度:
- 名称: 时间
字段: orders.created_at
类型: datetime
预定义范围:
今天: DATE(created_at) = CURDATE()
昨天: DATE(created_at) = DATE_SUB(CURDATE(), INTERVAL 1 DAY)
本周: created_at >= DATE_SUB(CURDATE(), INTERVAL WEEKDAY(CURDATE()) DAY)
本月: created_at >= DATE_FORMAT(CURDATE(), '%Y-%m-01')
上月: |
created_at >= DATE_SUB(DATE_FORMAT(CURDATE(), '%Y-%m-01'), INTERVAL 1 MONTH)
AND created_at < DATE_FORMAT(CURDATE(), '%Y-%m-01')
- 名称: 地区
字段: users.region
类型: string
可选值: [华东, 华北, 华南, 华中, 西南, 西北, 东北]
关联表: users
关联条件: orders.user_id = users.user_id
步骤 1:意图识别
用户问:"本月各地区的销售额"
识别结果:
- 指标:销售额(GMV)
- 维度:地区(region)
- 时间范围:本月
步骤 2:查询语义层
获取指标定义:
- 销售额的 SQL 模板
- 需要的表:orders
- 业务规则:status='paid', user_id>=10000
获取维度定义:
- 地区字段:users.region
- 需要关联:orders.user_id = users.user_id
获取时间定义:
- 本月:created_at >= DATE_FORMAT(CURDATE(), '%Y-%m-01')
步骤 3:组装 SQL
SELECT
u.region,
SUM(o.amount) as gmv
FROM orders o
JOIN users u ON o.user_id = u.user_id
WHERE o.status = 'paid'
AND o.user_id >= 10000
AND o.amount > 0
AND o.created_at >= DATE_FORMAT(CURDATE(), '%Y-%m-01')
GROUP BY u.region
ORDER BY gmv DESC
| 查询类型 | 之前 | 改进后 | 提升 |
|---|---|---|---|
| 简单查询 | 92% | 95% | +3% |
| 多表关联 | 70% | 85% | +15% |
| 聚合分析 | 80% | 90% | +10% |
| 时间对比 | 65% | 80% | +15% |
| 总体 | 75% | 85% | +10% |
GPT-3.5-turbo → GPT-4
| 维度 | GPT-3.5-turbo | GPT-4 |
|---|---|---|
| 复杂推理能力 | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
| SQL 生成准确率 | 75% | 88% |
| 响应速度 | 快(2-3秒) | 慢(5-8秒) |
| 成本 | 低 | 高(10倍) |
决策:
判断标准:
def should_use_gpt4(question):
# 包含复杂关键词
complex_keywords = ['对比', '同比', '环比', '增长率', '占比', '排名', '趋势']
if any(kw in question for kw in complex_keywords):
return True
# 涉及多个时间维度
if question.count('年') + question.count('月') + question.count('周') > 1:
return True
# 涉及多个指标
metrics = ['销售额', '订单量', '客单价', '用户数']
if sum(1 for m in metrics if m in question) > 1:
return True
return False
训练数据准备:
[
{
"messages": [
{"role": "system", "content": "你是一个专业的 SQL 生成助手..."},
{"role": "user", "content": "本月各地区的销售额"},
{"role": "assistant", "content": "SELECT u.region, SUM(o.amount) as gmv FROM orders o JOIN users u ON o.user_id = u.user_id WHERE o.status = 'paid' AND o.created_at >= DATE_FORMAT(CURDATE(), '%Y-%m-01') GROUP BY u.region"}
]
},
// ... 1000+ 个训练样本
]
微调效果:
测试集构建:
测试用例:
- id: test_001
问题: 今天的订单数量
标准 SQL: SELECT COUNT(*) FROM orders WHERE DATE(created_at) = CURDATE() AND status = 'paid'
预期结果: 数值类型
- id: test_002
问题: 本月各地区的销售额
标准 SQL: |
SELECT u.region, SUM(o.amount) as gmv
FROM orders o
JOIN users u ON o.user_id = u.user_id
WHERE o.status = 'paid'
AND o.created_at >= DATE_FORMAT(CURDATE(), '%Y-%m-01')
GROUP BY u.region
预期结果: 多行,包含 region 和 gmv 列
# ... 500 个测试用例
测试流程:
def run_test_suite():
results = []
for test_case in test_cases:
# 1. 用 Text-to-SQL 生成 SQL
generated_sql = text_to_sql(test_case['问题'])
# 2. 执行生成的 SQL
generated_result = execute_sql(generated_sql)
# 3. 执行标准 SQL
expected_result = execute_sql(test_case['标准 SQL'])
# 4. 对比结果
is_correct = compare_results(generated_result, expected_result)
results.append({
'test_id': test_case['id'],
'question': test_case['问题'],
'generated_sql': generated_sql,
'is_correct': is_correct
})
# 5. 计算准确率
accuracy = sum(r['is_correct'] for r in results) / len(results)
return accuracy, results
问题:生成的 SQL 和标准 SQL 可能写法不同,但结果相同。
示例:
-- 标准 SQL
SELECT SUM(amount) FROM orders WHERE status = 'paid'
-- 生成的 SQL(等价)
SELECT SUM(amount) as total FROM orders WHERE status IN ('paid')
解决方案:
def are_sqls_equivalent(sql1, sql2):
# 1. 执行两个 SQL
result1 = execute_sql(sql1)
result2 = execute_sql(sql2)
# 2. 对比结果(忽略列名、顺序)
return compare_results(result1, result2, ignore_column_names=True)
错误分类:
错误类型统计(基于 500 个测试用例):
- 表关系错误:15 个(3%)
- 时间处理错误:10 个(2%)
- 聚合逻辑错误:8 个(1.6%)
- 业务规则遗漏:5 个(1%)
- 其他:7 个(1.4%)
总错误数:45 个
准确率:(500-45)/500 = 91%
针对性优化:
迭代后:
对于关键查询,引入人工审核:
def text_to_sql_with_review(question, user_role):
# 1. 生成 SQL
sql = generate_sql(question)
# 2. 判断是否需要人工审核
if needs_human_review(question, sql, user_role):
# 显示生成的 SQL,等待用户确认
return {
'sql': sql,
'status': 'pending_review',
'message': '请确认生成的 SQL 是否正确'
}
else:
# 直接执行
result = execute_sql(sql)
return {
'sql': sql,
'result': result,
'status': 'executed'
}
def needs_human_review(question, sql, user_role):
# 1. 涉及敏感数据
if 'DELETE' in sql or 'UPDATE' in sql:
return True
# 2. 查询结果可能影响重大决策
if user_role == 'executive' and '销售额' in question:
return True
# 3. 复杂查询(包含子查询、窗口函数)
if 'SELECT' in sql and sql.count('SELECT') > 1:
return True
return False
用户提问
↓
意图识别层(理解用户想查什么)
↓
语义层查询(获取指标、维度定义)
↓
SQL 生成层(组装 SQL)
↓
验证层(语法检查、权限检查)
↓
执行层(执行 SQL,返回结果)
不要追求一步到位,按优先级逐步优化:
第一阶段:解决语法错误(60% → 75%)
第二阶段:解决语义错误(75% → 85%)
第三阶段:解决复杂场景(85% → 95%)
监控指标:
迭代流程:
收集错误案例
↓
分析错误原因
↓
针对性优化(Prompt/语义层/模型)
↓
测试验证
↓
上线发布
↓
持续监控
收集反馈:
利用反馈:
| 场景 | 推荐模型 | 原因 |
|---|---|---|
| 简单查询 | GPT-3.5-turbo | 成本低、速度快 |
| 复杂查询 | GPT-4 / Claude 3 | 推理能力强 |
| 私有化部署 | Qwen-72B / DeepSeek | 开源、可本地部署 |
| 成本敏感 | 微调后的小模型 | 性价比高 |
| 工具 | 优点 | 缺点 |
|---|---|---|
| dbt | 成熟、社区活跃 | 学习曲线陡 |
| Cube.js | 开源、灵活 | 需要开发能力 |
| AskTable 内置 | 开箱即用、易配置 | 功能相对简单 |
| 自研 | 完全可控 | 开发成本高 |
将 Text-to-SQL 准确率从 60% 提升到 95% 是一个系统工程,需要:
技术层面:
工程层面:
业务层面:
核心原则:
95% 的准确率是生产可用的门槛,但仍有 5% 的错误。对于关键业务场景,建议:
Text-to-SQL 技术正在快速发展,随着大模型能力的提升,未来准确率有望达到 98% 甚至更高。但无论技术如何进步,业务语义层和测试验证体系都是不可或缺的基础设施。
了解更多: