引入依赖(pom.xml)
<dependencyManagement>
    <dependencies>
        <!-- SNAPSHOT versions are only published to the Spring snapshot repository, which must be declared in <repositories>. -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-bom</artifactId>
            <version>1.0.0-SNAPSHOT</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- Provides JdbcTemplate, which the controller uses directly; declared explicitly instead of relying on a transitive dependency. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-jdbc</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-redis-store</artifactId>
    </dependency>
    <dependency>
        <groupId>redis.clients</groupId>
        <artifactId>jedis</artifactId>
        <version>5.1.0</version>
    </dependency>
    <dependency>
        <groupId>com.baomidou</groupId>
        <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
        <version>3.5.7</version>
    </dependency>
    <!-- Since 8.0.31 the MySQL driver is published as com.mysql:mysql-connector-j; mysql:mysql-connector-java is only a relocation stub. -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <version>8.0.32</version>
    </dependency>
    <dependency>
        <groupId>org.jfree</groupId>
        <artifactId>jfreechart</artifactId>
        <version>1.5.3</version>
    </dependency>
    <!-- The controller imports org.apache.commons.lang3.StringUtils; version managed by Spring Boot. -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
    </dependency>
</dependencies>
代码(SqlController.java)
package com.qjc.demo.controller;

import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.lang3.StringUtils;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtils;
import org.jfree.chart.JFreeChart;
import org.jfree.data.general.DefaultPieDataset;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;

import java.io.IOException;
import java.io.OutputStream;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;

import javax.sql.DataSource;
@Controller
public class SqlController {@Resourceprivate ChatModel chatModel;@Resourceprivate JdbcTemplate jdbcTemplate;private final String FILTER_INSTRUCTION = """你需要根据指定的Input从Instruction中筛选出最相关的表信息(可能是单个表或多个表),首先,我将给你展示一个示例,Instruction后面跟着Input和对应的Response,然后,我会给你一个新的Instruction和新的Input,你需要生成一个新的Response来完成任务。### Example1 Instruction:job(id, name, age), user(id, name, age), student(id, name, age, info)### Example1 Input:Find the age of student table### Example1 Response:student(id, name, age, info)###New Instruction:{instruction}###New Input:{input}###New Response:""";private final String GENERATE_INSTRUCTION = """你扮演一个SQL终端,您只需要返回SQL命令给我,而不需要返回其他任何字符。下面是一个描述任务的Instruction,返回适当的结果完成Input对应的请求.###Instruction:{instruction}###Input:{input}###Response:""";@GetMapping("/chat")public void chat(@RequestParam("query") String query, HttpServletResponse response) throws SQLException, IOException {Map<String, List<String>> tableInfo = getTableInfo();List<String> tableInfoList = tableInfo.entrySet().stream().map(entry -> String.format("%s(%s)", entry.getKey(), StringUtils.join(entry.getValue(), ","))).toList();String tableInfoPrompt = StringUtils.join(tableInfoList, ",");PromptTemplate filtePromptTemplate = new PromptTemplate(FILTER_INSTRUCTION);filtePromptTemplate.add("instruction", tableInfoPrompt);filtePromptTemplate.add("input", query);String filterPrompt = filtePromptTemplate.render();String filterResult = chatModel.call(filterPrompt);PromptTemplate generatePromptTemplate = new PromptTemplate(GENERATE_INSTRUCTION);generatePromptTemplate.add("instruction", filterResult);generatePromptTemplate.add("input", query);String generatePrompt = generatePromptTemplate.render();String sql = chatModel.call(generatePrompt);sql = sql.replace("```sql", "");sql = sql.replace("```", "");System.out.println(sql);List<Map<String, Object>> maps = jdbcTemplate.queryForList(sql);DefaultPieDataset dataset = new DefaultPieDataset();for (Map<String, Object> map : maps) {Object[] values = 
map.values().toArray();dataset.setValue(values[0].toString(), Integer.valueOf(values[1].toString()));}JFreeChart chart = ChartFactory.createPieChart("统计结果", dataset, false,true,true);response.setContentType("image/png");OutputStream out = response.getOutputStream();ChartUtils.writeChartAsPNG(out, chart, 800, 600);out.flush();}public Map<String, List<String>> getTableInfo() throws SQLException {DatabaseMetaData metaData = jdbcTemplate.getDataSource().getConnection().getMetaData();ResultSet tables = metaData.getTables(null, null, "%", new String[]{"TABLE"});Map<String, List<String>> result = new HashMap<>();while (tables.next()) {String tableName = tables.getString("TABLE_NAME");ResultSet columns = metaData.getColumns(null, null, tableName, null);ArrayList<String> columnNames = new ArrayList<>();while (columns.next()) {String columnName = columns.getString("COLUMN_NAME");String remarks = columns.getString("REMARKS");columnNames.add(String.format("%s(%s)", columnName, remarks));}result.put(tableName, columnNames);}return result;}
}