AskTable

图表改进 Agent - 基于上下文的迭代优化

AskTable 团队
AskTable 团队 2026年3月4日

图表改进 Agent - 基于上下文的迭代优化

在数据可视化场景中,用户往往需要多次调整图表才能达到理想效果。传统的做法是重新生成整个图表,但这会丢失之前的上下文和用户意图。AskTable 的 ChartImprovementAgent 采用了一种更智能的方法:保留原始上下文,理解改进意图,迭代优化图表

问题背景

用户在使用 AI 生成图表时,常见的改进需求包括:

如果每次都重新生成,LLM 可能会:

ChartImprovementAgent 架构

加载图表中...

核心设计

1. 上下文保留

ChartImprovementAgent 在初始化时接收完整的上下文信息:

class ChartImprovementAgent:
    def __init__(
        self,
        parent_nodes_context: list[dict],  # 父节点数据
        original_question: str,             # 原始问题
        original_code: str,                 # 原始 JSX 代码
        improvement_request: str,           # 改进请求
    ):
        self.parent_nodes_context = parent_nodes_context
        self.original_question = original_question
        self.original_code = original_code
        self.improvement_request = improvement_request

父节点上下文包含:

parent_nodes_context = [
    {
        "id": "node_123",
        "sql": "SELECT region, SUM(sales) as total FROM orders GROUP BY region",
        "description": "各地区销售汇总",
        "dataframe": {
            "columns": ["region", "total"],
            "sample_data": [
                {"region": "华东", "total": 1000000},
                {"region": "华北", "total": 800000},
            ]
        }
    }
]

2. 系统 Prompt 构建

将上下文信息格式化为系统 Prompt:

# 格式化父节点信息
parent_info_parts = []
for i, node in enumerate(parent_nodes_context, 1):
    parent_info_parts.append(f"Node {i} (ID: {node['id']}):")
    if node.get("description"):
        parent_info_parts.append(f"  Description: {node['description']}")
    if node.get("sql"):
        parent_info_parts.append(f"  SQL: {node['sql']}")
    if node.get("dataframe"):
        df = node["dataframe"]
        parent_info_parts.append(f"  Columns: {', '.join(df.get('columns', []))}")
        if df.get("sample_data"):
            parent_info_parts.append(f"  Sample rows: {len(df['sample_data'])}")
    parent_info_parts.append("")

formatted_parent_info = "\n".join(parent_info_parts)

# 构建系统 Prompt
system_prompt = get_prompt("agent/canvas/edit_chart").compile(
    formatted_parent_info=formatted_parent_info,
    original_question=original_question,
    original_code=original_code,
)

3. 改进工具

Agent 提供

submit_improved_chart
工具,用于提交改进后的代码:

def submit_improved_chart(
    self,
    question: str = Field(
        ...,
        description="The rewritten question that naturally incorporates both original requirement and the improvement",
    ),
    description: str = Field(
        ...,
        description="Brief description of the improvements made (1-2 sentences), or error reason if code is None",
    ),
    code: str | None = Field(
        None,
        description="The improved JSX code for the chart component. Set to None if improvement failed.",
    ),
) -> str:
    if code is not None:
        try:
            # 编译 JSX 代码
            self.compiled_code = compile_jsx(code)
            self.source_code = code
            log.info("Improved chart code compiled successfully")
        except Exception as e:
            log.error(f"Failed to compile improved chart code: {str(e)}")
            raise ValueError(f"Code compilation failed: {str(e)}")
    else:
        # 改进失败,记录原因
        self.compiled_code = None
        self.source_code = None
        log.warning(f"Chart improvement failed: {description}")

    self.submitted_description = description
    self.submitted_question = question

    status_msg = "error" if code is None else "success"
    return json.dumps({
        "question": question,
        "description": description,
        "status": status_msg,
        "has_code": code is not None,
    })

4. 问题重写

Agent 需要将原始问题和改进请求合并为新的问题描述:

示例

这样可以保持问题的完整性,便于后续理解和维护。

实际应用示例

示例 1: 图表类型转换

原始代码

<code>
import { BarChart } from '@/components/charts';

function RegionSales() {
  const data = load_dataframe('df_region_sales');

  return (
    <BarChart
      data={data}
      xField="region"
      yField="total"
      title="各地区销售额"
    />
  );
}
</code>

改进请求:「改成折线图」

改进后代码

<code>
import { LineChart } from '@/components/charts';

function RegionSales() {
  const data = load_dataframe('df_region_sales');

  return (
    <LineChart
      data={data}
      xField="region"
      yField="total"
      title="各地区销售额"
      smooth={true}
    />
  );
}
</code>

示例 2: 数据筛选

改进请求:「只显示销售额前 5 的地区」

改进后代码

<code>
import { BarChart } from '@/components/charts';

function RegionSales() {
  const data = load_dataframe('df_region_sales');

  // 按销售额降序排列,取前 5
  const topData = data
    .sort((a, b) => b.total - a.total)
    .slice(0, 5);

  return (
    <BarChart
      data={topData}
      xField="region"
      yField="total"
      title="销售额 Top 5 地区"
    />
  );
}
</code>

示例 3: 样式优化

改进请求:「使用蓝色系配色,图例放在底部」

改进后代码

<code>
import { BarChart } from '@/components/charts';

function RegionSales() {
  const data = load_dataframe('df_region_sales');

  return (
    <BarChart
      data={data}
      xField="region"
      yField="total"
      title="各地区销售额"
      color="#1890ff"
      legend={{
        position: 'bottom'
      }}
    />
  );
}
</code>

示例 4: 复杂改进

改进请求:「添加平均线,标注最高值,使用渐变色」

改进后代码

<code>
import { BarChart } from '@/components/charts';

function RegionSales() {
  const data = load_dataframe('df_region_sales');

  // 计算平均值
  const average = data.reduce((sum, item) => sum + item.total, 0) / data.length;

  // 找到最高值
  const maxItem = data.reduce((max, item) => item.total > max.total ? item : max);

  return (
    <div>
      <BarChart
        data={data}
        xField="region"
        yField="total"
        title="各地区销售额"
        color={{
          type: 'gradient',
          colors: ['#1890ff', '#52c41a']
        }}
        annotations={[
          {
            type: 'line',
            start: ['min', average],
            end: ['max', average],
            style: { stroke: '#ff4d4f', lineDash: [4, 4] },
            text: { content: `平均值: ${average.toFixed(0)}`, position: 'end' }
          },
          {
            type: 'text',
            position: [maxItem.region, maxItem.total],
            content: `最高: ${maxItem.total}`,
            style: { fill: '#ff4d4f', fontWeight: 'bold' }
          }
        ]}
      />
    </div>
  );
}
</code>

错误处理

1. 编译失败

如果改进后的代码无法编译,Agent 会返回错误信息:

try:
    self.compiled_code = compile_jsx(code)
except Exception as e:
    raise ValueError(f"Code compilation failed: {str(e)}")

2. 改进不可行

如果 LLM 判断改进请求无法实现,可以返回

code=None

def submit_improved_chart(
    self,
    question: str,
    description: str,
    code: str | None = None,  # None 表示改进失败
):
    if code is None:
        # 记录失败原因
        self.compiled_code = None
        self.source_code = None
        log.warning(f"Chart improvement failed: {description}")

示例

3. 数据源验证

确保改进后的代码仍然引用正确的数据源:

# 提取 load_dataframe 引用
load_df_pattern = r"load_dataframe\(\s*['\"]( df_[A-Za-z0-9]+)['\"]\s*\)"
referenced_dataframes = re.findall(load_df_pattern, code)

# 验证数据源是否存在
missing_ids = set(referenced_dataframes) - set(self.data_workspace.keys())
if missing_ids:
    raise ValueError(f"Referenced dataframes {missing_ids} are not in the data workspace")

与 ChartNodeAgent 的对比

特性ChartNodeAgentChartImprovementAgent
用途从零生成图表改进现有图表
输入用户问题 + 数据原始代码 + 改进请求
上下文父节点数据父节点数据 + 原始代码 + 原始问题
输出新的 JSX 代码改进后的 JSX 代码 + 重写的问题
问题描述用户原始问题合并原始问题和改进请求

用户体验优化

1. 渐进式改进

用户可以多次迭代改进图表:

用户: 展示各地区销售额
→ 生成柱状图

用户: 改成折线图
→ 改进为折线图

用户: 只显示前 5 名
→ 添加数据筛选

用户: 使用蓝色系配色
→ 调整颜色方案

2. 上下文保持

每次改进都保留之前的所有优化:

// 第一次改进:改成折线图
<LineChart ... />

// 第二次改进:只显示前 5 名
const topData = data.slice(0, 5);
<LineChart data={topData} ... />

// 第三次改进:使用蓝色系配色
const topData = data.slice(0, 5);
<LineChart data={topData} color="#1890ff" ... />

3. 失败降级

如果改进失败,保留原始图表:

def get_result(self) -> dict:
    if self.source_code is None:
        return {
            "code": None,
            "compiled_code": None,
            "question": self.submitted_question or self.original_question,
            "description": self.submitted_description,
            "status": "error",
            "error": self.submitted_description,
        }

    return {
        "code": self.source_code,
        "compiled_code": self.compiled_code,
        "question": self.submitted_question,
        "description": self.submitted_description,
        "status": "success",
        "error": None,
    }

性能优化

1. 增量编译

只编译改进后的代码,不重新编译整个项目:

self.compiled_code = compile_jsx(code)  # 单文件编译

2. 缓存复用

复用父节点数据,避免重复查询:

# 父节点数据已经包含 DataFrame
parent_nodes_context = [
    {
        "id": "node_123",
        "dataframe": cached_dataframe  # 复用缓存
    }
]

3. 并行处理

多个改进请求可以并行处理:

# 并行处理多个改进
agents = [
    ChartImprovementAgent(..., improvement_request="改成折线图"),
    ChartImprovementAgent(..., improvement_request="只显示前 5 名"),
]

results = await asyncio.gather(*[agent.run() for agent in agents])

总结

AskTable 的 ChartImprovementAgent 通过以下技术实现了智能的图表迭代优化:

  1. 上下文保留:保存原始问题、代码和数据源
  2. 意图理解:LLM 理解用户的改进需求
  3. 问题重写:合并原始问题和改进请求
  4. 代码验证:JSX 编译和数据源验证
  5. 错误降级:改进失败时保留原始图表

这种设计不仅提升了用户体验,还保证了图表改进的可靠性和一致性,是 AI 驱动数据可视化系统的重要组成部分。

相关资源