
企业微信

飞书
选择您喜欢的方式加入群聊

扫码添加咨询专家
在数据可视化场景中,用户往往需要多次调整图表才能达到理想效果。传统的做法是重新生成整个图表,但这会丢失之前的上下文和用户意图。AskTable 的 ChartImprovementAgent 采用了一种更智能的方法:保留原始上下文,理解改进意图,迭代优化图表。
用户在使用 AI 生成图表时,常见的改进需求包括:
如果每次都重新生成,LLM 可能会:
加载图表中...
ChartImprovementAgent 在初始化时接收完整的上下文信息:
class ChartImprovementAgent: def __init__( self, parent_nodes_context: list[dict], # 父节点数据 original_question: str, # 原始问题 original_code: str, # 原始 JSX 代码 improvement_request: str, # 改进请求 ): self.parent_nodes_context = parent_nodes_context self.original_question = original_question self.original_code = original_code self.improvement_request = improvement_request
父节点上下文包含:
parent_nodes_context = [ { "id": "node_123", "sql": "SELECT region, SUM(sales) as total FROM orders GROUP BY region", "description": "各地区销售汇总", "dataframe": { "columns": ["region", "total"], "sample_data": [ {"region": "华东", "total": 1000000}, {"region": "华北", "total": 800000}, ] } } ]
将上下文信息格式化为系统 Prompt:
# 格式化父节点信息 parent_info_parts = [] for i, node in enumerate(parent_nodes_context, 1): parent_info_parts.append(f"Node {i} (ID: {node['id']}):") if node.get("description"): parent_info_parts.append(f" Description: {node['description']}") if node.get("sql"): parent_info_parts.append(f" SQL: {node['sql']}") if node.get("dataframe"): df = node["dataframe"] parent_info_parts.append(f" Columns: {', '.join(df.get('columns', []))}") if df.get("sample_data"): parent_info_parts.append(f" Sample rows: {len(df['sample_data'])}") parent_info_parts.append("") formatted_parent_info = "\n".join(parent_info_parts) # 构建系统 Prompt system_prompt = get_prompt("agent/canvas/edit_chart").compile( formatted_parent_info=formatted_parent_info, original_question=original_question, original_code=original_code, )
Agent 提供
submit_improved_chart 工具,用于提交改进后的代码:
def submit_improved_chart( self, question: str = Field( ..., description="The rewritten question that naturally incorporates both original requirement and the improvement", ), description: str = Field( ..., description="Brief description of the improvements made (1-2 sentences), or error reason if code is None", ), code: str | None = Field( None, description="The improved JSX code for the chart component. Set to None if improvement failed.", ), ) -> str: if code is not None: try: # 编译 JSX 代码 self.compiled_code = compile_jsx(code) self.source_code = code log.info("Improved chart code compiled successfully") except Exception as e: log.error(f"Failed to compile improved chart code: {str(e)}") raise ValueError(f"Code compilation failed: {str(e)}") else: # 改进失败,记录原因 self.compiled_code = None self.source_code = None log.warning(f"Chart improvement failed: {description}") self.submitted_description = description self.submitted_question = question status_msg = "error" if code is None else "success" return json.dumps({ "question": question, "description": description, "status": status_msg, "has_code": code is not None, })
Agent 需要将原始问题和改进请求合并为新的问题描述:
示例:
这样可以保持问题的完整性,便于后续理解和维护。
原始代码:
<code> import { BarChart } from '@/components/charts'; function RegionSales() { const data = load_dataframe('df_region_sales'); return ( <BarChart data={data} xField="region" yField="total" title="各地区销售额" /> ); } </code>
改进请求:「改成折线图」
改进后代码:
<code> import { LineChart } from '@/components/charts'; function RegionSales() { const data = load_dataframe('df_region_sales'); return ( <LineChart data={data} xField="region" yField="total" title="各地区销售额" smooth={true} /> ); } </code>
改进请求:「只显示销售额前 5 的地区」
改进后代码:
<code> import { BarChart } from '@/components/charts'; function RegionSales() { const data = load_dataframe('df_region_sales'); // 按销售额降序排列,取前 5 const topData = data .sort((a, b) => b.total - a.total) .slice(0, 5); return ( <BarChart data={topData} xField="region" yField="total" title="销售额 Top 5 地区" /> ); } </code>
改进请求:「使用蓝色系配色,图例放在底部」
改进后代码:
<code> import { BarChart } from '@/components/charts'; function RegionSales() { const data = load_dataframe('df_region_sales'); return ( <BarChart data={data} xField="region" yField="total" title="各地区销售额" color="#1890ff" legend={{ position: 'bottom' }} /> ); } </code>
改进请求:「添加平均线,标注最高值,使用渐变色」
改进后代码:
<code> import { BarChart } from '@/components/charts'; function RegionSales() { const data = load_dataframe('df_region_sales'); // 计算平均值 const average = data.reduce((sum, item) => sum + item.total, 0) / data.length; // 找到最高值 const maxItem = data.reduce((max, item) => item.total > max.total ? item : max); return ( <div> <BarChart data={data} xField="region" yField="total" title="各地区销售额" color={{ type: 'gradient', colors: ['#1890ff', '#52c41a'] }} annotations={[ { type: 'line', start: ['min', average], end: ['max', average], style: { stroke: '#ff4d4f', lineDash: [4, 4] }, text: { content: `平均值: ${average.toFixed(0)}`, position: 'end' } }, { type: 'text', position: [maxItem.region, maxItem.total], content: `最高: ${maxItem.total}`, style: { fill: '#ff4d4f', fontWeight: 'bold' } } ]} /> </div> ); } </code>
如果改进后的代码无法编译,Agent 会返回错误信息:
try: self.compiled_code = compile_jsx(code) except Exception as e: raise ValueError(f"Code compilation failed: {str(e)}")
如果 LLM 判断改进请求无法实现,可以返回
code=None:
def submit_improved_chart( self, question: str, description: str, code: str | None = None, # None 表示改进失败 ): if code is None: # 记录失败原因 self.compiled_code = None self.source_code = None log.warning(f"Chart improvement failed: {description}")
示例:
code=None, description="当前图表库不支持 3D 效果"确保改进后的代码仍然引用正确的数据源:
# 提取 load_dataframe 引用 load_df_pattern = r"load_dataframe\(\s*['\"]( df_[A-Za-z0-9]+)['\"]\s*\)" referenced_dataframes = re.findall(load_df_pattern, code) # 验证数据源是否存在 missing_ids = set(referenced_dataframes) - set(self.data_workspace.keys()) if missing_ids: raise ValueError(f"Referenced dataframes {missing_ids} are not in the data workspace")
| 特性 | ChartNodeAgent | ChartImprovementAgent |
|---|---|---|
| 用途 | 从零生成图表 | 改进现有图表 |
| 输入 | 用户问题 + 数据 | 原始代码 + 改进请求 |
| 上下文 | 父节点数据 | 父节点数据 + 原始代码 + 原始问题 |
| 输出 | 新的 JSX 代码 | 改进后的 JSX 代码 + 重写的问题 |
| 问题描述 | 用户原始问题 | 合并原始问题和改进请求 |
用户可以多次迭代改进图表:
用户: 展示各地区销售额 → 生成柱状图 用户: 改成折线图 → 改进为折线图 用户: 只显示前 5 名 → 添加数据筛选 用户: 使用蓝色系配色 → 调整颜色方案
每次改进都保留之前的所有优化:
// 第一次改进:改成折线图 <LineChart ... /> // 第二次改进:只显示前 5 名 const topData = data.slice(0, 5); <LineChart data={topData} ... /> // 第三次改进:使用蓝色系配色 const topData = data.slice(0, 5); <LineChart data={topData} color="#1890ff" ... />
如果改进失败,保留原始图表:
def get_result(self) -> dict: if self.source_code is None: return { "code": None, "compiled_code": None, "question": self.submitted_question or self.original_question, "description": self.submitted_description, "status": "error", "error": self.submitted_description, } return { "code": self.source_code, "compiled_code": self.compiled_code, "question": self.submitted_question, "description": self.submitted_description, "status": "success", "error": None, }
只编译改进后的代码,不重新编译整个项目:
self.compiled_code = compile_jsx(code) # 单文件编译
复用父节点数据,避免重复查询:
# 父节点数据已经包含 DataFrame parent_nodes_context = [ { "id": "node_123", "dataframe": cached_dataframe # 复用缓存 } ]
多个改进请求可以并行处理:
# 并行处理多个改进 agents = [ ChartImprovementAgent(..., improvement_request="改成折线图"), ChartImprovementAgent(..., improvement_request="只显示前 5 名"), ] results = await asyncio.gather(*[agent.run() for agent in agents])
AskTable 的 ChartImprovementAgent 通过以下技术实现了智能的图表迭代优化:
这种设计不仅提升了用户体验,还保证了图表改进的可靠性和一致性,是 AI 驱动数据可视化系统的重要组成部分。