AskTable

AskTable Schema Linking: A Hybrid of Vector Retrieval and Full-Text Search for Accurate Text-to-SQL

AskTable Team, March 5, 2026

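In Text-to-SQL, the biggest challenge is not producing valid SQL syntax; it is getting the AI to find the right tables and fields. An enterprise database may hold hundreds of tables and thousands of fields, so how can the metadata relevant to a user's natural-language question be located precisely?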

AskTable's Schema Linking engine combines vector retrieval, full-text search, and Training Pair reuse into a hybrid strategy that retrieves the relevant metadata with high accuracy.

This article takes a close look at how the engine is designed and implemented.


I. What Is Schema Linking?

1. The Core Challenge of Text-to-SQL

Problem: a user asks, "Which product had the highest sales last month?"

Database

What the AI needs to answer

The task of Schema Linking: from a large pool of metadata, find the tables and fields that are relevant to the question.

2. Limitations of Traditional Approaches

Passing the full schema

# Pass every table and field to the LLM
prompt = f"数据库有以下表:{all_tables}\n问题:{question}"

Problems

Keyword matching

# Naive string matching
relevant_tables = [t for t in all_tables if keyword in t.name]

Problems
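A toy illustration of the keyword-matching problem, using a hypothetical three-table schema invented for this example: a business term such as "sales" never appears literally in the column that stores it (`orders.amount`), while a generic term such as "name" matches several unrelated tables.

```python
from dataclasses import dataclass


@dataclass
class Table:
    name: str
    columns: list[str]


# Hypothetical schema: the business concept "sales" is stored in orders.amount
ALL_TABLES = [
    Table("orders", ["id", "product_id", "amount", "created_at"]),
    Table("products", ["id", "name", "category"]),
    Table("users", ["id", "name", "signup_at"]),
]


def keyword_match(keywords: list[str], tables: list[Table]) -> list[str]:
    """Return tables whose own name or any column name contains a keyword."""
    return [
        t.name
        for t in tables
        if any(kw in name for kw in keywords for name in [t.name, *t.columns])
    ]


print(keyword_match(["sales"], ALL_TABLES))  # [] : synonym miss
print(keyword_match(["name"], ALL_TABLES))   # ['products', 'users'] : ambiguous hit
```

Neither failure can be fixed by better string matching alone, which is what motivates semantic (vector) retrieval.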


II. AskTable's Hybrid Retrieval Strategy

1. Three-Layer Retrieval Architecture


2. Core Components

Qdrant (vector retrieval)

Meilisearch (full-text search)

Training Pair (example reuse)
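The three components answer different questions (which fields are semantically related, which columns contain a literal value, which past questions look similar), so they do not depend on one another. A minimal sketch, with hypothetical stub retrievers standing in for Qdrant, Meilisearch, and the training-pair store, of issuing all three concurrently with `asyncio.gather`:

```python
import asyncio


# Hypothetical stand-ins for the three channels; real versions would call
# Qdrant, Meilisearch, and the training-pair store respectively.
async def retrieve_fields(subqueries: list[str]) -> list[dict]:
    await asyncio.sleep(0.01)  # simulated network latency
    return [{"table": "orders", "field": "amount", "source": "vector"}]


async def retrieve_values(keywords: list[str]) -> list[dict]:
    await asyncio.sleep(0.01)
    return [{"table": "orders", "field": "status", "value": "completed", "source": "fulltext"}]


async def retrieve_examples(question: str) -> list[dict]:
    await asyncio.sleep(0.01)
    return [{"question": "Top product this month?", "sql": "SELECT 1", "source": "training_pair"}]


async def hybrid_retrieve(question: str, keywords: list[str], subqueries: list[str]):
    # The channels are independent, so await them together
    return await asyncio.gather(
        retrieve_fields(subqueries),
        retrieve_values(keywords),
        retrieve_examples(question),
    )


fields, values, examples = asyncio.run(
    hybrid_retrieve(
        "Which product had the highest sales last month?",
        ["sales", "product", "last month"],
        ["Aggregate sales by product"],
    )
)
print(fields[0]["source"], values[0]["source"], examples[0]["source"])
# vector fulltext training_pair
```

With concurrent dispatch the wall-clock cost is roughly that of the slowest channel rather than the sum of all three. The `_rag_link` code later in this article awaits the three retrievers one after another, so this is an optional variation, not a description of AskTable's implementation.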


III. Core Implementation: Inside the Source Code

1. Question Rewriting: Extracting Keywords and Subqueries

async def _rewrite_question(self, question: Question) -> None:
    """Rewrite the user question into keywords and subqueries"""
    if not question.subqueries:
        response = await prompt_generate(
            "query.extract_keywords_from_question",
            QUESTION=question.text,
            SPECIFICATION=question.specification,
            EVIDENCE=question.evidence,
        )
        question.keywords = response["keywords"]
        question.subqueries = response["subqueries"]
        log.info(f"extracted keywords: {question.keywords}")

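Example

# Input
question = "Which product had the highest sales last month?"

# Output
keywords = ["sales", "product", "last month"]
subqueries = [
    "Query the sales records table",
    "Aggregate sales by product",
    "Filter to last month's data"
]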
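The rewrite step depends on the LLM returning well-formed keywords and subqueries. A defensive parsing sketch, assuming (hypothetically) that the raw model output is a JSON object with `keywords` and `subqueries` arrays; `normalize_rewrite` is an illustrative name, not part of AskTable:

```python
import json


def normalize_rewrite(raw: str, question: str) -> tuple[list[str], list[str]]:
    """Parse a rewrite response into (keywords, subqueries), tolerating bad output."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        data = {}
    if not isinstance(data, dict):
        data = {}

    def clean(items) -> list[str]:
        if not isinstance(items, list):
            return []
        # Strip whitespace, drop empty or non-string items, dedupe in order
        seen = dict.fromkeys(
            s.strip() for s in items if isinstance(s, str) and s.strip()
        )
        return list(seen)

    keywords = clean(data.get("keywords"))
    subqueries = clean(data.get("subqueries"))
    # Fall back to the question itself so field retrieval still has a query
    return keywords, subqueries or [question]


keywords, subqueries = normalize_rewrite(
    '{"keywords": ["sales", " sales ", "", "product"], "subqueries": []}',
    "Which product had the highest sales last month?",
)
print(keywords)    # ['sales', 'product']
print(subqueries)  # ['Which product had the highest sales last month?']
```

Falling back to the question as the only subquery is a design choice for this sketch: `_retrieve_fields` skips vector retrieval entirely when it receives an empty list.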

Key points

2. Vector Retrieval: Finding Relevant Fields

async def _retrieve_fields(self, queries: list[str]) -> list[RetrievedMetaEntity]:
    """Find relevant fields via vector retrieval"""
    if not queries:
        log.warning("No subqueries provided, skipping field retrieval")
        return []
    return await self.ds.retrieve_fields_by_question(queries)

Qdrant retrieval implementation

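import asyncio


async def retrieve_fields_by_question(
    self, queries: list[str]
) -> list[RetrievedMetaEntity]:
    """Vector retrieval of fields, one search per subquery"""
    # 1. Embed all subqueries in a single call
    query_vectors = await self.embedding_model.encode(queries)

    # 2. Search Qdrant once per subquery vector, concurrently
    results = await asyncio.gather(*[
        self.qdrant_client.search(
            collection_name=f"meta_{self.id}",
            query_vector=vec,
            limit=20,
            score_threshold=0.7,
        )
        for vec in query_vectors
    ])

    # 3. Deduplicate across subqueries, keeping each point's best score
    best = {}
    for hits in results:
        for hit in hits:
            if hit.id not in best or hit.score > best[hit.id].score:
                best[hit.id] = hit

    return [
        {"id": hit.id, "payload": hit.payload, "score": hit.score}
        for hit in best.values()
    ]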
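Qdrant answers these searches with an approximate HNSW index, but the meaning of `limit` and `score_threshold` is easiest to see in an exhaustive version. An illustration in plain Python, with made-up 2-D vectors standing in for real embeddings:

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(
    query: list[float],
    points: dict[str, list[float]],
    limit: int = 20,
    score_threshold: float = 0.7,
) -> list[tuple[str, float]]:
    """Exhaustive similarity search: score every point, drop scores below
    the threshold, and return at most `limit` results, best first."""
    scored = ((pid, cosine(query, vec)) for pid, vec in points.items())
    kept = [(pid, s) for pid, s in scored if s >= score_threshold]
    return sorted(kept, key=lambda p: p[1], reverse=True)[:limit]


# Made-up 2-D "embeddings" keyed by field name
points = {
    "orders.amount": [1.0, 0.0],
    "orders.created_at": [0.6, 0.8],
    "users.name": [0.0, 1.0],
}
res = top_k([0.8, 0.6], points)
print([(pid, round(s, 2)) for pid, s in res])
# [('orders.created_at', 0.96), ('orders.amount', 0.8)]
```

`users.name` scores 0.6 and is cut by the 0.7 threshold even though `limit` would allow it; HNSW produces approximately the same ranking without scoring every point.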

Key points

3. Full-Text Search: Finding Field Values

async def _retrieve_values(self, keywords: list[str]) -> list[RetrievedMetaEntity]:
    """Find field values via full-text search"""
    if not config.aisearch_host or not config.aisearch_master_key:
        log.warning("Value index is not enabled, skipping value retrieval")
    elif not keywords:
        log.warning("No keywords provided, skipping value retrieval")
    else:
        values = await self.ds.retrieve_values_by_question(keywords)
        return values
    return []

Meilisearch retrieval implementation

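async def retrieve_values_by_question(
    self, keywords: list[str]
) -> list[RetrievedMetaEntity]:
    """Full-text search over field values"""
    # 1. Build the search query
    query = " ".join(keywords)

    # 2. Search Meilisearch. "id" must be among the retrieved attributes because
    #    it is read below, and hits only carry _rankingScore when showRankingScore
    #    is enabled. (Parameter names follow the Meilisearch search API; adapt
    #    them to the conventions of the client library in use.)
    results = await self.meilisearch_client.index(f"values_{self.id}").search(
        query,
        limit=50,
        attributesToRetrieve=["id", "schema_name", "table_name", "field_name", "value"],
        showRankingScore=True,
    )

    # 3. Return the hits
    return [
        {
            "id": hit["id"],
            "payload": {
                "schema_name": hit["schema_name"],
                "table_name": hit["table_name"],
                "field_name": hit["field_name"],
                "value": hit["value"],
                "type": "value",
            },
            "score": hit["_rankingScore"],
        }
        for hit in results["hits"]
    ]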
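Value retrieval only works if column values were indexed beforehand. A sketch of what that indexing step might produce, assuming one document per distinct value carrying the fields the search code reads back; `build_value_documents` is hypothetical, and the id is hashed because Meilisearch document ids may only contain letters, digits, hyphens, and underscores:

```python
import hashlib


def build_value_documents(
    schema_name: str, table_name: str, field_name: str, values: list[str]
) -> list[dict]:
    """Turn a column's values into full-text documents, one per distinct value."""
    docs = []
    for value in dict.fromkeys(values):  # dedupe, keep first-seen order
        key = f"{schema_name}.{table_name}.{field_name}={value}"
        docs.append({
            # Hash the key: raw values may contain characters not allowed in ids
            "id": hashlib.sha1(key.encode("utf-8")).hexdigest(),
            "schema_name": schema_name,
            "table_name": table_name,
            "field_name": field_name,
            "value": value,
        })
    return docs


docs = build_value_documents(
    "public", "orders", "status", ["completed", "pending payment", "completed"]
)
print(len(docs), [d["value"] for d in docs])  # 2 ['completed', 'pending payment']
```

A list like this could then be added to the `values_{id}` index with the client's add-documents call.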

Key points

4. Training Pair Retrieval: Reusing Historical Examples

async def _retrieve_examples(self, query: str) -> list[TrainingPair]:
    """Retrieve similar historical question-SQL pairs"""
    translation_examples = await retrieve_training_pairs(
        datasource_id=self.ds.id,
        query=query,
        role_id=self.role.id if self.role else None,
    )
    return translation_examples

Training Pair storage structure

from typing import TypedDict


class TrainingPair(TypedDict):
    question: str  # historical question
    sql: str  # the corresponding SQL
    score: float  # similarity score

Example

# Current question
question = "Which product had the highest sales last month?"

# Similar questions retrieved
training_pairs = [
    {
        "question": "Which item has the highest sales this month?",
        "sql": "SELECT product_name, SUM(amount) FROM sales WHERE month = CURRENT_MONTH GROUP BY product_name ORDER BY SUM(amount) DESC LIMIT 1",
        "score": 0.92,
    },
    {
        "question": "Best-selling product last year?",
        "sql": "SELECT product_id, COUNT(*) FROM orders WHERE year = LAST_YEAR GROUP BY product_id ORDER BY COUNT(*) DESC LIMIT 1",
        "score": 0.85,
    },
]
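Retrieved pairs end up in the prompt as few-shot examples. A sketch of that rendering, assuming a minimum similarity cutoff and a cap on the number of examples; the function name and both defaults are illustrative, not AskTable's `dict_to_markdown`:

```python
def training_pairs_to_prompt(
    pairs: list[dict], min_score: float = 0.8, limit: int = 3
) -> str:
    """Render the most similar question-SQL pairs as few-shot examples."""
    chosen = sorted(
        (p for p in pairs if p["score"] >= min_score),
        key=lambda p: p["score"],
        reverse=True,
    )[:limit]
    return "\n\n".join(f"Q: {p['question']}\nSQL: {p['sql']}" for p in chosen)


pairs = [
    {"question": "Best-selling product last year?", "sql": "SELECT 2", "score": 0.85},
    {"question": "Which item has the highest sales this month?", "sql": "SELECT 1", "score": 0.92},
    {"question": "How many users signed up?", "sql": "SELECT 3", "score": 0.41},
]
print(training_pairs_to_prompt(pairs))
```

The low-scoring pair is dropped and the closest match comes first, so the most relevant example sits nearest the instructions.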

Key points

5. Entity Merging: Deduplication and Aggregation

def _merge_values_fields(hits: list[RetrievedMetaEntity]) -> list[MetaEntity]:
    """Merge field hits and value hits into one entity per field"""
    fields_buckets: dict[tuple[str, str, str], set[str]] = {}

    for hit in hits:
        payload = hit["payload"]
        index = (payload["schema_name"], payload["table_name"], payload["field_name"])
        # Every hit registers its field; only value hits contribute sample values
        bucket = fields_buckets.setdefault(index, set())
        if payload.get("type") == "value":
            bucket.add(payload["value"])

    fields_list: list[MetaEntity] = []
    for index, values in fields_buckets.items():
        fields_list.append(
            {
                "schema_name": index[0],
                "table_name": index[1],
                "field_name": index[2],
                "sample_values": list(values),
            }
        )
    return fields_list
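One detail of the merge: Python randomizes string hashes per process by default, so iterating a `set` of strings can yield a different order on each run, and `sample_values` (and therefore the prompt text) may change between restarts. A sketch of an order-preserving variant that uses dict keys as an insertion-ordered set; `merge_hits` is a hypothetical name:

```python
def merge_hits(hits: list[dict]) -> list[dict]:
    """Group hits by (schema, table, field); collect value hits as samples,
    keeping both fields and values in first-seen order."""
    buckets: dict[tuple[str, str, str], dict[str, None]] = {}
    for hit in hits:
        p = hit["payload"]
        key = (p["schema_name"], p["table_name"], p["field_name"])
        bucket = buckets.setdefault(key, {})
        if p.get("type") == "value":
            bucket[p["value"]] = None  # dict keys act as an ordered set
    return [
        {"schema_name": s, "table_name": t, "field_name": f, "sample_values": list(vals)}
        for (s, t, f), vals in buckets.items()
    ]


def _value_hit(field: str, value: str) -> dict:
    return {"payload": {"schema_name": "public", "table_name": "orders",
                        "field_name": field, "type": "value", "value": value}}


hits = [
    _value_hit("status", "completed"),
    {"payload": {"schema_name": "public", "table_name": "orders",
                 "field_name": "amount", "type": "field"}},
    _value_hit("status", "pending payment"),
    _value_hit("status", "completed"),
]
for entity in merge_hits(hits):
    print(entity["field_name"], entity["sample_values"])
# status ['completed', 'pending payment']
# amount []
```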

Key points

6. Context Injection: Enriching Metadata Descriptions

def _add_context_to_meta(meta: MetaAdmin, entities: list[MetaEntity]):
    """Inject retrieved field values into the metadata descriptions"""
    for entity in entities:
        if schema := meta.schemas.get(entity["schema_name"]):
            if table := schema.tables.get(entity["table_name"]):
                if field := table.fields.get(entity["field_name"]):
                    values = [f'"{v}"' for v in entity["sample_values"]]
                    if values:
                        if field.curr_desc:
                            field.curr_desc += f"(e.g. {','.join(values)})"
                        else:
                            field.curr_desc = f"(e.g. {','.join(values)})"

Effect

# Original metadata
field = {
    "name": "status",
    "type": "VARCHAR",
    "description": "Order status"
}

# After context injection
field = {
    "name": "status",
    "type": "VARCHAR",
    "description": "Order status(e.g. \"completed\",\"pending payment\",\"cancelled\")"
}
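A self-contained sketch of the injection rule, using a minimal stand-in `Field` class. The `max_values` cap is a hypothetical addition to keep high-cardinality columns from flooding the prompt; the quoting and `(e.g. ...)` suffix follow `_add_context_to_meta`:

```python
from dataclasses import dataclass


@dataclass
class Field:
    name: str
    curr_desc: str = ""


def inject_samples(field: Field, sample_values: list[str], max_values: int = 5) -> None:
    """Append up to `max_values` quoted sample values to the field description."""
    values = [f'"{v}"' for v in sample_values[:max_values]]
    if values:
        # Works whether or not a description already exists
        field.curr_desc += f"(e.g. {','.join(values)})"


status = Field("status", "Order status")
inject_samples(status, ["completed", "pending payment", "cancelled"])
print(status.curr_desc)
# Order status(e.g. "completed","pending payment","cancelled")
```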

Key points

7. Table Selection: LLM Rerank

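async def _pick_tables(
    self,
    meta_candidate: MetaAdmin,
    specification: str,
    training_pairs: list[TrainingPair],
) -> list[tuple[str, str]]:
    """Use the LLM to rerank candidates and select the relevant tables"""
    # 1. Ask the LLM for the most relevant tables
    table_of_interest_ = await prompt_generate(
        "query.select_tables_by_question",
        meta_data=meta_candidate.to_markdown(),
        question=specification,
        translation_examples=dict_to_markdown(training_pairs),
    )
    table_of_interest = table_of_interest_["table_names"]

    if not table_of_interest:
        raise errors.NoDataToQuery(params={"message": "No data to query"})

    log.info(f"relevant table names: {table_of_interest}")

    # 2. Validate the "schema.table" format; skip malformed names rather than
    #    letting the tuple unpacking raise ValueError on a name without a dot
    pairs: list[tuple[str, str]] = []
    for table_name in table_of_interest:
        schema, sep, table = table_name.partition(".")
        if not (sep and schema and table):
            log.warning(f"skipping malformed table name: {table_name}")
            continue
        pairs.append((schema, table))

    if not pairs:
        raise errors.NoDataToQuery(params={"message": "No data to query"})

    return pairs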
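`_pick_tables` trusts the names the LLM returns. A sketch of a stricter guard that also rejects names absent from the candidate metadata (hallucinated tables) and drops duplicates; `validate_table_names` is hypothetical:

```python
def validate_table_names(
    names: list[str], candidates: set[tuple[str, str]]
) -> tuple[list[tuple[str, str]], list[str]]:
    """Split LLM-picked "schema.table" names into accepted pairs and rejects."""
    accepted: list[tuple[str, str]] = []
    rejected: list[str] = []
    for name in names:
        schema, sep, table = name.strip().partition(".")
        if not sep or (schema, table) not in candidates:
            rejected.append(name)  # malformed, or not in the candidate metadata
        elif (schema, table) not in accepted:
            accepted.append((schema, table))  # first occurrence only
    return accepted, rejected


candidates = {("public", "orders"), ("public", "products")}
accepted, rejected = validate_table_names(
    ["public.orders", "orders", "public.customers", "public.orders", " public.products "],
    candidates,
)
print(accepted)  # [('public', 'orders'), ('public', 'products')]
print(rejected)  # ['orders', 'public.customers']
```

Logging the rejects, rather than silently dropping them, makes prompt regressions easier to spot.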

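Key points


IV. Adaptive Multi-Mode Strategy

AskTable automatically picks a Schema Linking mode based on the size of the accessible metadata (its table and field counts):

1. Naive Mode

When to use

Strategy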

async def _naive_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """Naive mode: pass the full metadata"""
    # Retrieve only values and examples; tables are not filtered
    values = await self._retrieve_values(question.keywords or [])
    pairs = await self._retrieve_examples(question.specification)

    entities = _merge_values_fields(values)
    _add_context_to_meta(accessible_meta, entities)

    return {"meta": accessible_meta, "training_pairs": pairs}

Advantages

2. RAG Mode (retrieval-augmented)

When to use

Strategy

async def _rag_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """RAG mode: vector retrieval + full-text search"""
    # 1. Retrieve fields, values, and examples
    fields = await self._retrieve_fields(question.subqueries or [])
    values = await self._retrieve_values(question.keywords or [])
    pairs = await self._retrieve_examples(question.specification)

    # 2. Merge entities
    examples = _training_pair_to_entities(pairs, self.ds.dialect)
    entities = _merge_values_fields(values + fields + examples)
    _add_context_to_meta(accessible_meta, entities)

    # 3. Collect the tables that were hit
    hit_table_names: set[tuple[str, str]] = {
        (e["schema_name"], e["table_name"]) for e in entities
    }

    # 4. If more than 3 tables were hit, narrow them down with an LLM rerank
    if len(hit_table_names) > 3:
        fields_full_names = _get_field_full_names_from_entities(entities)
        hit_fields = accessible_meta.filter_fields_by_names(
            [convert_full_name_to_tuple(f) for f in fields_full_names]
        )
        table_of_interest = await self._pick_tables(
            hit_fields, question.specification, pairs
        )
        meta = accessible_meta.filter_tables_by_names(table_of_interest)
    else:
        meta = accessible_meta.filter_tables_by_names(list(hit_table_names))

    return {"meta": meta, "training_pairs": pairs}

Advantages

3. Reasoning Mode

When to use

Strategy

async def _reasoning_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """Reasoning mode: LLM-driven table selection"""
    # 1. Retrieve values and examples
    values = await self._retrieve_values(question.keywords or [])
    pairs = await self._retrieve_examples(question.specification)

    # 2. Inject context
    entities = _merge_values_fields(values)
    _add_context_to_meta(accessible_meta, entities)

    # 3. Let the LLM pick the relevant tables
    table_of_interest = await self._pick_tables(
        accessible_meta, question.specification, pairs
    )
    meta = accessible_meta.filter_tables_by_names(table_of_interest)

    return {"meta": meta, "training_pairs": pairs}

Advantages

4. Auto Mode

async def link(self, question: Question) -> MetaContext:
    """Choose the best mode automatically"""
    accessible_meta = self._get_accessible_meta()

    if config.at_schema_linking_mode == SchemaLinkingMode.auto:
        if accessible_meta.table_count <= 3 and accessible_meta.field_count <= 100:
            return await self._naive_link(accessible_meta, question)
        elif accessible_meta.table_count <= 7 and accessible_meta.field_count <= 300:
            return await self._reasoning_link(accessible_meta, question)
        else:
            return await self._rag_link(accessible_meta, question)
    elif config.at_schema_linking_mode == SchemaLinkingMode.naive:
        return await self._naive_link(accessible_meta, question)
    elif config.at_schema_linking_mode == SchemaLinkingMode.rag:
        return await self._rag_link(accessible_meta, question)
    elif config.at_schema_linking_mode == SchemaLinkingMode.reasoning:
        return await self._reasoning_link(accessible_meta, question)
    else:
        raise ValueError(f"unknown schema linking mode: {config.at_schema_linking_mode}")
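The auto-mode thresholds are easier to test when lifted into a pure function. A sketch with a minimal stand-in for `SchemaLinkingMode` that has only the three concrete modes:

```python
from enum import Enum


class SchemaLinkingMode(str, Enum):
    naive = "naive"
    reasoning = "reasoning"
    rag = "rag"


def choose_mode(table_count: int, field_count: int) -> SchemaLinkingMode:
    """Same thresholds as the auto branch of link(), as a pure function."""
    if table_count <= 3 and field_count <= 100:
        return SchemaLinkingMode.naive
    if table_count <= 7 and field_count <= 300:
        return SchemaLinkingMode.reasoning
    return SchemaLinkingMode.rag


print(choose_mode(2, 150).value)  # reasoning
```

Note that the field count can override the table count: a database with only 2 tables but 150 fields goes to Reasoning Mode, not Naive Mode.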

V. Performance Optimization

1. Vector Index Tuning

HNSW index

# Qdrant collection configuration
collection_config = {
    "vectors": {
        "size": 1536,  # OpenAI embedding dimension
        "distance": "Cosine",
    },
    "hnsw_config": {
        "m": 16,  # edges per node in the HNSW graph
        "ef_construct": 100,  # candidate-list size while building the index
    },
}
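HNSW keeps vectors in memory, so the embedding dimension drives sizing. A back-of-the-envelope sketch using Qdrant's capacity-planning rule of thumb (vector count × dimension × 4 bytes per float32 × 1.5 for index and metadata overhead); the 100,000-field workload is a made-up example:

```python
def estimate_vector_memory(num_vectors: int, dim: int = 1536, overhead: float = 1.5) -> int:
    """Rough RAM estimate in bytes for in-memory float32 vectors.

    4 bytes per float32 component, multiplied by an overhead factor
    (1.5 per Qdrant's rule of thumb).
    """
    return int(num_vectors * dim * 4 * overhead)


# Hypothetical workload: 100,000 indexed fields, 1536-dimensional embeddings
print(estimate_vector_memory(100_000) / 1e9)  # 0.9216 (GB)
```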

Effect

2. Batched Retrieval

# Embed all queries in one call, then run one search per vector concurrently
query_vectors = await self.embedding_model.encode(queries)
results = await asyncio.gather(*[
    self.qdrant_client.search(
        collection_name=f"meta_{self.id}",
        query_vector=vec,
        limit=20,
    )
    for vec in query_vectors
])
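`asyncio.gather` returns results in the order the awaitables were passed, not the order they finish, so `results[i]` always pairs with `query_vectors[i]`. A small demonstration with a hypothetical `fake_search` in place of the Qdrant call:

```python
import asyncio


async def fake_search(query_id: int, delay: float) -> str:
    # Hypothetical stand-in for one vector search call
    await asyncio.sleep(delay)
    return f"results-for-query-{query_id}"


async def main() -> list[str]:
    # The slowest search is listed first, yet its result still comes first
    delays = [0.03, 0.01, 0.02]
    return await asyncio.gather(*(fake_search(i, d) for i, d in enumerate(delays)))


print(asyncio.run(main()))
# ['results-for-query-0', 'results-for-query-1', 'results-for-query-2']
```

The three simulated searches overlap, so the whole call takes about as long as the slowest one (0.03 s) rather than the sum (0.06 s).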

Effect

3. Caching Strategy

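# Cache embeddings in a plain dict: lru_cache on an async def would cache
# the coroutine object, which can be awaited only once
_embedding_cache: dict[str, list[float]] = {}

async def get_embedding(text: str) -> list[float]:
    if text not in _embedding_cache:
        _embedding_cache[text] = await embedding_model.encode(text)
    return _embedding_cache[text]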
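A bounded cache for an async function has to store awaited results, not coroutine objects, since a coroutine can be awaited only once. A sketch of a size-limited async LRU cache along those lines, with `fake_encode` as a hypothetical stand-in for the embedding model:

```python
import asyncio
from collections import OrderedDict
from collections.abc import Awaitable, Callable


class AsyncLRUCache:
    """Size-limited LRU cache for an async function of one string argument.

    Stores awaited results, so a cached value can be returned any number of
    times. Two concurrent misses on the same key may both call the underlying
    function; this sketch accepts that for simplicity.
    """

    def __init__(self, fn: Callable[[str], Awaitable[list[float]]], maxsize: int = 1000):
        self.fn = fn
        self.maxsize = maxsize
        self.data: OrderedDict[str, list[float]] = OrderedDict()

    async def __call__(self, text: str) -> list[float]:
        if text in self.data:
            self.data.move_to_end(text)  # mark as most recently used
            return self.data[text]
        value = await self.fn(text)
        self.data[text] = value
        if len(self.data) > self.maxsize:
            self.data.popitem(last=False)  # evict the least recently used entry
        return value


async def fake_encode(text: str) -> list[float]:
    # Hypothetical embedding call: a 1-D "vector" holding the text length
    return [float(len(text))]


get_embedding = AsyncLRUCache(fake_encode, maxsize=1000)
print(asyncio.run(get_embedding("sales")))  # [5.0]
```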

Effect


VI. Case Studies

Case 1: A simple query

# Question
question = "How many users are there?"

# Schema Linking result
meta = {
    "tables": [
        {
            "name": "users",
            "fields": [
                {"name": "id", "type": "INT"},
                {"name": "name", "type": "VARCHAR"},
            ]
        }
    ]
}

# Generated SQL
sql = "SELECT COUNT(*) FROM users"

Case 2: A more complex query

# Question
question = "Which product had the highest sales last month?"

# Schema Linking result
meta = {
    "tables": [
        {
            "name": "orders",
            "fields": [
                {"name": "product_id", "type": "INT"},
                {"name": "amount", "type": "DECIMAL", "description": "Sales amount(e.g. \"1000.00\",\"2500.50\")"},
                {"name": "created_at", "type": "TIMESTAMP"},
            ]
        },
        {
            "name": "products",
            "fields": [
                {"name": "id", "type": "INT"},
                {"name": "name", "type": "VARCHAR", "description": "Product name(e.g. \"iPhone\",\"MacBook\")"},
            ]
        }
    ],
    "training_pairs": [
        {
            "question": "Which item had the highest sales this month?",
            "sql": "SELECT p.name, SUM(o.amount) FROM orders o JOIN products p ON o.product_id = p.id WHERE MONTH(o.created_at) = MONTH(NOW()) GROUP BY p.name ORDER BY SUM(o.amount) DESC LIMIT 1"
        }
    ]
}

# Generated SQL (MySQL): "last month" is the previous calendar month,
# expressed as a half-open range [first day of last month, first day of this month)
sql = """
SELECT p.name, SUM(o.amount) AS total_amount
FROM orders o
JOIN products p ON o.product_id = p.id
WHERE o.created_at >= DATE_FORMAT(NOW() - INTERVAL 1 MONTH, '%Y-%m-01')
  AND o.created_at < DATE_FORMAT(NOW(), '%Y-%m-01')
GROUP BY p.name
ORDER BY total_amount DESC
LIMIT 1
"""

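One way to make "last month" unambiguous across SQL dialects is to compute the previous calendar month's bounds in application code and bind them as query parameters. A sketch; `previous_month_range` is an illustrative helper, not part of AskTable:

```python
from datetime import date, timedelta


def previous_month_range(today: date) -> tuple[date, date]:
    """Half-open range [start, end) covering the previous calendar month."""
    end = today.replace(day=1)                        # first day of this month
    start = (end - timedelta(days=1)).replace(day=1)  # first day of last month
    return start, end


print(previous_month_range(date(2026, 3, 5)))
# (datetime.date(2026, 2, 1), datetime.date(2026, 3, 1))
```

The result plugs into a condition such as `created_at >= start AND created_at < end`, which also handles the January/December year boundary.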
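VII. Summary and Outlook

Through its hybrid of vector retrieval, full-text search, and Training Pair reuse, AskTable's Schema Linking engine delivers:

✅ High accuracy: combining several retrieval channels improves both recall and precision
✅ Low latency: HNSW indexing plus batched retrieval keeps retrieval under 50 ms
✅ Adaptive: the best mode is chosen automatically from the size of the database
✅ Scalable: supports large databases (1,000+ tables)

Future directions

  1. Graph retrieval: use inter-table relationships (foreign keys) to improve retrieval
  2. Reinforcement learning: tune the retrieval strategy based on user feedback
  3. Multilingual support: handle mixed Chinese-English queries
  4. Real-time updates: refresh the indexes automatically when metadata changes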
