From f7e614dee24bbf5157c2b2b28d5dc86c53a5ebab Mon Sep 17 00:00:00 2001 From: jack ning Date: Tue, 15 Apr 2025 14:37:45 +0800 Subject: [PATCH] update --- .../ai/springai/base/BaseSpringAIService.java | 118 +++++++++++------- 1 file changed, 73 insertions(+), 45 deletions(-) diff --git a/modules/ai/src/main/java/com/bytedesk/ai/springai/base/BaseSpringAIService.java b/modules/ai/src/main/java/com/bytedesk/ai/springai/base/BaseSpringAIService.java index febd3063cf..7329054efc 100644 --- a/modules/ai/src/main/java/com/bytedesk/ai/springai/base/BaseSpringAIService.java +++ b/modules/ai/src/main/java/com/bytedesk/ai/springai/base/BaseSpringAIService.java @@ -102,59 +102,87 @@ public abstract class BaseSpringAIService implements SpringAIService { // sendSseTypingMessage(messageProtobuf, emitter); // 判断是否开启大模型 if (robot.getLlm().isEnabled()) { - - - + // 启用大模型 + processLlmResponse(query, robot, messageProtobufQuery, messageProtobufReply, emitter); + } else { + // 未开启大模型,关键词匹配,使用搜索 + processSearchResponse(query, robot, messageProtobufQuery, messageProtobufReply, emitter); } } + private void processLlmResponse(String query, RobotProtobuf robot, MessageProtobuf messageProtobufQuery, + MessageProtobuf messageProtobufReply, SseEmitter emitter) { + // + String prompt = ""; + if (StringUtils.hasText(robot.getKbUid()) && robot.getIsKbEnabled()) { + List contentList = springAIVectorService.get().searchText(query, robot.getKbUid()); + if (contentList.isEmpty()) { + // 直接返回未找到相关问题答案 + String answer = RobotConsts.ROBOT_UNMATCHED; + processAnswerMessage(answer, robot, messageProtobufQuery, messageProtobufReply, emitter); + return; + } + String context = String.join("\n", contentList); + // TODO: 根据配置,拉取历史聊天记录 + // String history = ""; + prompt = buildKbPrompt(robot.getLlm().getPrompt(), query, context); + } else { + prompt = robot.getLlm().getPrompt(); + } + // TODO: 返回消息中携带消息搜索结果(来源依据) + // + List messages = new ArrayList<>(); + messages.add(new SystemMessage(prompt)); + messages.add(new UserMessage(query)); + log.info("BaseSpringAIService sendSseMemberMessage messages {}", messages); + // + Prompt aiPrompt = new Prompt(messages); + processPromptSSE(aiPrompt, messageProtobufQuery, messageProtobufReply, emitter); + } - private void processLlmResponse(MessageProtobuf messageProtobufQuery, MessageProtobuf messageProtobufReply, - SseEmitter emitter) { - // - String prompt = ""; - if (StringUtils.hasText(robot.getKbUid()) && robot.getIsKbEnabled()) { - List contentList = springAIVectorService.get().searchText(query, robot.getKbUid()); - if (contentList.isEmpty()) { - // 直接返回未找到相关问题答案 - messageProtobufReply.setType(MessageTypeEnum.TEXT); - messageProtobufReply.setContent(RobotConsts.ROBOT_UNMATCHED); - messageProtobufReply.setClient(ClientEnum.SYSTEM); - // 保存消息到数据库 - persistMessage(messageProtobufQuery, messageProtobufReply); - String messageJson = messageProtobufReply.toJson(); - try { - // 发送SSE事件 - emitter.send(SseEmitter.event() - .data(messageJson) - .id(messageProtobufReply.getUid()) - .name("message")); - } catch (Exception e) { - log.error("BaseSpringAIService sendSseMemberMessage Error sending SSE event 1:", e); - emitter.completeWithError(e); - } - return; - } - String context = String.join("\n", contentList); - // TODO: 根据配置,拉取历史聊天记录 - // String history = ""; - prompt = buildKbPrompt(robot.getLlm().getPrompt(), query, context); + private void processSearchResponse(String query, RobotProtobuf robot, MessageProtobuf messageProtobufQuery, + MessageProtobuf messageProtobufReply, SseEmitter emitter) { + + if (StringUtils.hasText(robot.getKbUid()) && robot.getIsKbEnabled()) { + List contentList = springAIVectorService.get().searchText(query, robot.getKbUid()); + if (contentList.isEmpty()) { + // 直接返回未找到相关问题答案 + String answer = RobotConsts.ROBOT_UNMATCHED; + processAnswerMessage(answer, robot, messageProtobufQuery, messageProtobufReply, emitter); + return; } else { - prompt = robot.getLlm().getPrompt(); - } - // TODO: 返回消息中携带消息搜索结果(来源依据) - // - List messages = new ArrayList<>(); - messages.add(new SystemMessage(prompt)); - messages.add(new UserMessage(query)); - log.info("BaseSpringAIService sendSseMemberMessage messages {}", messages); - // - Prompt aiPrompt = new Prompt(messages); - processPromptSSE(aiPrompt, messageProtobufQuery, messageProtobufReply, emitter); + // 搜索到内容,返回搜索内容 + String answer = String.join("\n", contentList); + processAnswerMessage(answer, robot, messageProtobufQuery, messageProtobufReply, emitter); } + } else { + // 未设置知识库 + // 直接返回未找到相关问题答案 + String answer = RobotConsts.ROBOT_UNMATCHED; + processAnswerMessage(answer, robot, messageProtobufQuery, messageProtobufReply, emitter); + } + } - + private void processAnswerMessage(String answer, RobotProtobuf robot, MessageProtobuf messageProtobufQuery, + MessageProtobuf messageProtobufReply, SseEmitter emitter) { + messageProtobufReply.setType(MessageTypeEnum.TEXT); + messageProtobufReply.setContent(answer); + messageProtobufReply.setClient(ClientEnum.SYSTEM); + // 保存消息到数据库 + persistMessage(messageProtobufQuery, messageProtobufReply); + String messageJson = messageProtobufReply.toJson(); + try { + // 发送SSE事件 + emitter.send(SseEmitter.event() + .data(messageJson) + .id(messageProtobufReply.getUid()) + .name("message")); + } catch (Exception e) { + log.error("BaseSpringAIService sendSseMemberMessage Error sending SSE event 1:", e); + emitter.completeWithError(e); + } + } @Override public String generateFaqPairsAsync(String chunk) {