
企业微信

飞书
选择您喜欢的方式加入群聊

扫码添加咨询专家
在 Text-to-SQL 场景中,最大的挑战不是生成 SQL 语法,而是如何让 AI 找到正确的表和字段。一个企业数据库可能有数百张表、数千个字段,如何从用户的自然语言问题中精准定位到相关的元数据?
AskTable 的 Schema Linking 引擎,通过 向量检索 + 全文搜索 + Training Pair 复用 的混合策略,实现了高准确率的元数据检索。
本文将深入剖析这套引擎的设计与实现。
问题:用户问"上个月销售额最高的产品是什么?"
数据库:
- 表:orders, products, customers, sales_records, ...
- 字段:order_id, product_name, sale_amount, created_at, ...

AI 需要回答:
- 应该查哪张表?(orders? sales_records?)
- 销售额对应哪个字段?(sale_amount? total_price?)
- 时间对应哪个字段?(created_at? order_date?)

Schema Linking 的任务:从海量元数据中,找出与问题相关的表和字段。
全量传递:
# 将所有表和字段传递给 LLM prompt = f"数据库有以下表:{all_tables}\n问题:{question}"
问题:
关键词匹配:
# 简单的字符串匹配 relevant_tables = [t for t in all_tables if keyword in t.name]
问题:
加载图表中...
Qdrant(向量检索):
Meilisearch(全文搜索):
Training Pair(示例复用):
async def _rewrite_question(self, question: Question) -> None:
    """Rewrite the user's question into keywords and sub-queries.

    Nothing happens when ``question.subqueries`` is already populated.
    Otherwise the ``query.extract_keywords_from_question`` prompt is run and
    its ``keywords`` / ``subqueries`` entries are stored on ``question``.
    """
    if question.subqueries:
        return

    extracted = await prompt_generate(
        "query.extract_keywords_from_question",
        QUESTION=question.text,
        SPECIFICATION=question.specification,
        EVIDENCE=question.evidence,
    )
    question.keywords = extracted["keywords"]
    question.subqueries = extracted["subqueries"]
    log.info(f"extracted keywords: {question.keywords}")
示例:
# 输入 question = "上个月销售额最高的产品是什么?" # 输出 keywords = ["销售额", "产品", "上个月"] subqueries = [ "查询销售记录表", "按产品分组统计销售额", "筛选上个月的数据" ]
关键点:
async def _retrieve_fields(self, queries: list[str]) -> list[RetrievedMetaEntity]:
    """Look up relevant fields through vector retrieval.

    Delegates to ``self.ds.retrieve_fields_by_question``. When ``queries`` is
    empty a warning is logged and an empty list is returned instead.
    """
    if queries:
        return await self.ds.retrieve_fields_by_question(queries)

    log.warning("No subqueries provided, skipping field retrieval")
    return []
Qdrant 检索实现:
async def retrieve_fields_by_question(
    self, queries: list[str]
) -> list[RetrievedMetaEntity]:
    """Retrieve fields by vector search, using every query in ``queries``.

    Each query is embedded and searched against the ``meta_<id>`` collection
    (top 20 hits, score threshold 0.7). Hits from all queries are merged:
    a point returned by several queries is kept once, with its best score.

    Args:
        queries: Natural-language sub-queries to embed and search with.

    Returns:
        ``{"id", "payload", "score"}`` dicts sorted by descending score;
        an empty list when ``queries`` is empty.
    """
    import asyncio

    # Nothing to embed; also avoids indexing into an empty vector list.
    if not queries:
        return []

    # 1. Turn the queries into vectors.
    query_vectors = await self.embedding_model.encode(queries)

    # 2. Search with every vector, not only the first one, so that later
    #    sub-queries also contribute candidates.
    per_query_hits = await asyncio.gather(
        *[
            self.qdrant_client.search(
                collection_name=f"meta_{self.id}",
                query_vector=vector,
                limit=20,
                score_threshold=0.7,
            )
            for vector in query_vectors
        ]
    )

    # 3. Merge, keeping the highest score seen for each point id.
    best_by_id: dict = {}
    for hits in per_query_hits:
        for hit in hits:
            known = best_by_id.get(hit.id)
            if known is None or hit.score > known["score"]:
                best_by_id[hit.id] = {
                    "id": hit.id,
                    "payload": hit.payload,
                    "score": hit.score,
                }

    return sorted(best_by_id.values(), key=lambda entry: entry["score"], reverse=True)
关键点:
score_threshold=0.7 过滤低相关结果
async def _retrieve_values(self, keywords: list[str]) -> list[RetrievedMetaEntity]: """通过全文搜索查找字段值""" if not config.aisearch_host or not config.aisearch_master_key: log.warning("Value index is not enabled, skipping value retrieval") elif not keywords: log.warning("No keywords provided, skipping value retrieval") else: values = await self.ds.retrieve_values_by_question(keywords) return values return []
Meilisearch 检索实现:
async def retrieve_values_by_question(
    self, keywords: list[str]
) -> list[RetrievedMetaEntity]:
    """Retrieve field values by full-text search.

    The keywords are joined with spaces and searched in the ``values_<id>``
    index (up to 50 hits).

    Args:
        keywords: Keywords extracted from the user's question.

    Returns:
        ``{"id", "payload", "score"}`` dicts whose payload carries
        ``schema_name``, ``table_name``, ``field_name``, ``value`` and
        ``type == "value"``; an empty list when ``keywords`` is empty.
    """
    # An empty query would not be a keyword search at all.
    if not keywords:
        return []

    # 1. Build the search query.
    query = " ".join(keywords)

    # 2. Search. "id" has to be listed explicitly because only the listed
    #    attributes come back, and the ranking score is only included when
    #    it is requested.
    results = await self.meilisearch_client.index(f"values_{self.id}").search(
        query,
        limit=50,
        attributesToRetrieve=["id", "schema_name", "table_name", "field_name", "value"],
        showRankingScore=True,
    )

    # 3. Shape the hits into retrieval entities.
    return [
        {
            "id": hit["id"],
            "payload": {
                "schema_name": hit["schema_name"],
                "table_name": hit["table_name"],
                "field_name": hit["field_name"],
                "value": hit["value"],
                "type": "value",
            },
            # Fall back to 0.0 if the server did not attach a ranking score.
            "score": hit.get("_rankingScore", 0.0),
        }
        for hit in results["hits"]
    ]
关键点:
async def _retrieve_examples(self, query: str) -> list[TrainingPair]:
    """Fetch historical question/SQL pairs similar to ``query``.

    The lookup is scoped to ``self.ds.id`` and, when ``self.role`` is set,
    to that role's id (otherwise ``None`` is passed as the role id).
    """
    datasource_id = self.ds.id
    role_id = self.role.id if self.role else None
    return await retrieve_training_pairs(
        datasource_id=datasource_id,
        query=query,
        role_id=role_id,
    )
Training Pair 存储结构:
class TrainingPair(TypedDict):
    """A stored question/SQL example together with its similarity score."""

    question: str  # historical question
    sql: str  # SQL that answered that question
    score: float  # similarity score
示例:
# 当前问题 question = "上个月销售额最高的产品是什么?" # 检索到的相似问题 training_pairs = [ { "question": "本月销售额最高的商品是哪个?", "sql": "SELECT product_name, SUM(amount) FROM sales WHERE month = CURRENT_MONTH GROUP BY product_name ORDER BY SUM(amount) DESC LIMIT 1", "score": 0.92, }, { "question": "去年销量最好的产品?", "sql": "SELECT product_id, COUNT(*) FROM orders WHERE year = LAST_YEAR GROUP BY product_id ORDER BY COUNT(*) DESC LIMIT 1", "score": 0.85, }, ]
关键点:
def _merge_values_fields(hits: list[RetrievedMetaEntity]) -> list[MetaEntity]: """合并字段和值检索结果""" fields_buckets: dict[tuple, set] = {} for hit in hits: index = ( hit["payload"]["schema_name"], hit["payload"]["table_name"], hit["payload"]["field_name"], ) if not fields_buckets.get(index): fields_buckets[index] = set() bucket = fields_buckets[index] if hit["payload"]["type"] == "value": bucket.add(hit["payload"]["value"]) fields_list: list[MetaEntity] = [] for index, values in fields_buckets.items(): fields_list.append( { "schema_name": index[0], "table_name": index[1], "field_name": index[2], "sample_values": list(values), } ) return fields_list
关键点:
def _add_context_to_meta(meta: MetaAdmin, entities: list[MetaEntity]): """将检索到的字段值注入到元数据描述中""" for entity in entities: if schema := meta.schemas.get(entity["schema_name"]): if table := schema.tables.get(entity["table_name"]): if field := table.fields.get(entity["field_name"]): values = [f'"{v}"' for v in entity["sample_values"]] if values: if field.curr_desc: field.curr_desc += f"(e.g. {','.join(values)})" else: field.curr_desc = f"(e.g. {','.join(values)})"
效果:
# 原始元数据 field = { "name": "status", "type": "VARCHAR", "description": "订单状态" } # 注入上下文后 field = { "name": "status", "type": "VARCHAR", "description": "订单状态(e.g. \"已完成\",\"待支付\",\"已取消\")" }
关键点:
async def _pick_tables(
    self,
    meta_candidate: MetaAdmin,
    specification: str,
    training_pairs: list[TrainingPair],
) -> list[tuple[str, str]]:
    """Ask the LLM to choose the relevant tables among the candidates.

    Args:
        meta_candidate: Candidate metadata, rendered to markdown for the prompt.
        specification: The question specification given to the prompt.
        training_pairs: Example pairs, rendered to markdown for the prompt.

    Returns:
        ``(schema, table)`` tuples parsed from the ``schema.table`` names the
        LLM returned.

    Raises:
        errors.NoDataToQuery: If the LLM returns no table names, or none of
            the returned names has the ``schema.table`` form.
    """
    # 1. Let the LLM choose the most relevant tables.
    selection = await prompt_generate(
        "query.select_tables_by_question",
        meta_data=meta_candidate.to_markdown(),
        question=specification,
        translation_examples=dict_to_markdown(training_pairs),
    )
    table_of_interest = selection["table_names"]
    if not table_of_interest:
        raise errors.NoDataToQuery(params={"message": "No data to query"})
    log.info(f"relevant table names: {table_of_interest}")

    # 2. Validate the name format. LLM output is not guaranteed to be
    #    "schema.table", so malformed names are skipped instead of failing
    #    the whole request on tuple unpacking.
    pairs: list[tuple[str, str]] = []
    for table_name in table_of_interest:
        schema, dot, table = table_name.partition(".")
        if not (dot and schema and table):
            log.warning(f"ignoring malformed table name: {table_name!r}")
            continue
        pairs.append((schema, table))

    if not pairs:
        raise errors.NoDataToQuery(params={"message": "No data to query"})
    return pairs
关键点:
AskTable 根据数据库规模,自动选择最优的 Schema Linking 模式:
适用场景:
策略:
async def _naive_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """Naive mode: hand over the full metadata without filtering tables.

    Only values and examples are retrieved; the retrieved values are injected
    into ``accessible_meta``, which is returned as-is together with the
    training pairs.
    """
    value_hits = await self._retrieve_values(question.keywords or [])
    training_pairs = await self._retrieve_examples(question.specification)

    _add_context_to_meta(accessible_meta, _merge_values_fields(value_hits))
    return {"meta": accessible_meta, "training_pairs": training_pairs}
优势:
适用场景:
策略:
async def _rag_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """RAG mode: vector retrieval plus full-text search.

    Fields, values and examples are retrieved and merged into entities, whose
    sample values are injected into ``accessible_meta``. The metadata is then
    narrowed to the tables that were hit; when more than three tables were
    hit, ``_pick_tables`` makes the choice among the hit fields.
    """
    # 1. Retrieve fields, values and examples.
    field_hits = await self._retrieve_fields(question.subqueries or [])
    value_hits = await self._retrieve_values(question.keywords or [])
    training_pairs = await self._retrieve_examples(question.specification)

    # 2. Merge everything into entities and inject the sample values.
    example_hits = _training_pair_to_entities(training_pairs, self.ds.dialect)
    entities = _merge_values_fields(value_hits + field_hits + example_hits)
    _add_context_to_meta(accessible_meta, entities)

    # 3. Collect the distinct tables that were hit.
    hit_table_names: set[tuple[str, str]] = {
        (entity["schema_name"], entity["table_name"]) for entity in entities
    }

    # 4. Few tables: keep them all. Too many: let the LLM rerank.
    if len(hit_table_names) <= 3:
        selected = list(hit_table_names)
    else:
        full_names = _get_field_full_names_from_entities(entities)
        hit_fields = accessible_meta.filter_fields_by_names(
            [convert_full_name_to_tuple(name) for name in full_names]
        )
        selected = await self._pick_tables(
            hit_fields, question.specification, training_pairs
        )

    meta = accessible_meta.filter_tables_by_names(selected)
    return {"meta": meta, "training_pairs": training_pairs}
优势:
适用场景:
策略:
async def _reasoning_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """Reasoning mode: table selection is driven by the LLM.

    Values and examples are retrieved, the values are injected into
    ``accessible_meta``, and ``_pick_tables`` then chooses the tables the
    returned metadata is filtered down to.
    """
    # 1. Retrieve values and examples.
    value_hits = await self._retrieve_values(question.keywords or [])
    training_pairs = await self._retrieve_examples(question.specification)

    # 2. Inject the retrieved values as context.
    _add_context_to_meta(accessible_meta, _merge_values_fields(value_hits))

    # 3. Let the LLM choose the relevant tables.
    chosen = await self._pick_tables(
        accessible_meta, question.specification, training_pairs
    )
    return {
        "meta": accessible_meta.filter_tables_by_names(chosen),
        "training_pairs": training_pairs,
    }
优势:
async def link(self, question: Question) -> MetaContext:
    """Run schema linking in the configured mode.

    In ``auto`` mode the mode is chosen from the size of the accessible
    metadata: naive for at most 3 tables and 100 fields, reasoning for at
    most 7 tables and 300 fields, RAG otherwise.

    Args:
        question: The question to link against the accessible metadata.

    Returns:
        The metadata context produced by the selected linking mode.

    Raises:
        ValueError: If ``config.at_schema_linking_mode`` is not one of the
            known modes.
    """
    accessible_meta = self._get_accessible_meta()
    mode = config.at_schema_linking_mode

    # Resolve "auto" to a concrete mode first so the dispatch below is shared.
    if mode == SchemaLinkingMode.auto:
        if accessible_meta.table_count <= 3 and accessible_meta.field_count <= 100:
            mode = SchemaLinkingMode.naive
        elif accessible_meta.table_count <= 7 and accessible_meta.field_count <= 300:
            mode = SchemaLinkingMode.reasoning
        else:
            mode = SchemaLinkingMode.rag

    if mode == SchemaLinkingMode.naive:
        return await self._naive_link(accessible_meta, question)
    if mode == SchemaLinkingMode.rag:
        return await self._rag_link(accessible_meta, question)
    if mode == SchemaLinkingMode.reasoning:
        return await self._reasoning_link(accessible_meta, question)

    # Fail loudly instead of implicitly returning None to the caller.
    raise ValueError(f"Unsupported schema linking mode: {mode!r}")
HNSW 索引:
# Qdrant 配置 collection_config = { "vectors": { "size": 1536, # OpenAI embedding 维度 "distance": "Cosine", }, "hnsw_config": { "m": 16, # 连接数 "ef_construct": 100, # 构建时的搜索深度 }, }
效果:
# 批量检索多个查询 query_vectors = await self.embedding_model.encode(queries) results = await asyncio.gather(*[ self.qdrant_client.search( collection_name=f"meta_{self.id}", query_vector=vec, limit=20, ) for vec in query_vectors ])
效果:
# Cache embedding results.
# functools.lru_cache must not wrap an ``async def``: it would cache the
# coroutine object, and awaiting that object a second time raises
# RuntimeError. The awaited vectors are cached instead.
_EMBEDDING_CACHE_MAXSIZE = 1000
_embedding_cache: dict[str, list[float]] = {}


async def get_embedding(text: str) -> list[float]:
    """Return the embedding of ``text``, reusing a cached vector if present.

    At most ``_EMBEDDING_CACHE_MAXSIZE`` vectors are kept; the least recently
    used one is evicted first.

    Args:
        text: The text to embed.

    Returns:
        The vector produced by ``embedding_model.encode(text)``.
    """
    if text in _embedding_cache:
        # Re-insert so the entry becomes the most recently used one
        # (dicts keep insertion order).
        vector = _embedding_cache.pop(text)
        _embedding_cache[text] = vector
        return vector

    vector = await embedding_model.encode(text)
    _embedding_cache[text] = vector
    if len(_embedding_cache) > _EMBEDDING_CACHE_MAXSIZE:
        # The first key is the least recently used entry.
        del _embedding_cache[next(iter(_embedding_cache))]
    return vector
效果:
# 问题 question = "有多少个用户?" # Schema Linking 结果 meta = { "tables": [ { "name": "users", "fields": [ {"name": "id", "type": "INT"}, {"name": "name", "type": "VARCHAR"}, ] } ] } # 生成的 SQL sql = "SELECT COUNT(*) FROM users"
# Question
question = "上个月销售额最高的产品是什么?"

# Schema Linking result
meta = {
    "tables": [
        {
            "name": "orders",
            "fields": [
                {"name": "product_id", "type": "INT"},
                {"name": "amount", "type": "DECIMAL", "description": "销售额(e.g. \"1000.00\",\"2500.50\")"},
                {"name": "created_at", "type": "TIMESTAMP"},
            ]
        },
        {
            "name": "products",
            "fields": [
                {"name": "id", "type": "INT"},
                {"name": "name", "type": "VARCHAR", "description": "产品名称(e.g. \"iPhone\",\"MacBook\")"},
            ]
        }
    ],
    "training_pairs": [
        {
            "question": "本月销售额最高的商品?",
            # Filter from the first day of the current month; comparing
            # MONTH() values alone would also match the same month of
            # earlier years.
            "sql": "SELECT p.name, SUM(o.amount) FROM orders o JOIN products p ON o.product_id = p.id WHERE o.created_at >= DATE_FORMAT(NOW(), '%Y-%m-01') GROUP BY p.name ORDER BY SUM(o.amount) DESC LIMIT 1"
        }
    ]
}

# Generated SQL
# "Last month" means the previous calendar month, i.e. the half-open range
# [first day of last month, first day of this month), not the last 30 days.
sql = """
SELECT p.name, SUM(o.amount) as total_amount
FROM orders o
JOIN products p ON o.product_id = p.id
WHERE o.created_at >= DATE_FORMAT(NOW() - INTERVAL 1 MONTH, '%Y-%m-01')
  AND o.created_at < DATE_FORMAT(NOW(), '%Y-%m-01')
GROUP BY p.name
ORDER BY total_amount DESC
LIMIT 1
"""
AskTable 的 Schema Linking 引擎,通过 向量检索 + 全文搜索 + Training Pair 复用 的混合策略,实现了:
- ✅ 高准确率:多路混合检索(向量 + 全文 + 示例)提升召回率和精准度
- ✅ 低延迟:HNSW 索引 + 批量检索 < 50ms
- ✅ 自适应:根据数据库规模自动选择最优模式
- ✅ 可扩展:支持大型数据库(1000+ 表)
相关阅读:
技术交流: