This commit is contained in:
jack ning
2025-04-15 14:37:45 +08:00
parent ae4d752732
commit f7e614dee2

View File

@@ -102,59 +102,87 @@ public abstract class BaseSpringAIService implements SpringAIService {
// sendSseTypingMessage(messageProtobuf, emitter);
// 判断是否开启大模型
if (robot.getLlm().isEnabled()) {
// 启用大模型
processLlmResponse(query, robot, messageProtobufQuery, messageProtobufReply, emitter);
} else {
// 未开启大模型,关键词匹配,使用搜索
processSearchResponse(query, robot, messageProtobufQuery, messageProtobufReply, emitter);
}
}
private void processLlmResponse(String query, RobotProtobuf robot, MessageProtobuf messageProtobufQuery,
MessageProtobuf messageProtobufReply, SseEmitter emitter) {
//
String prompt = "";
if (StringUtils.hasText(robot.getKbUid()) && robot.getIsKbEnabled()) {
List<String> contentList = springAIVectorService.get().searchText(query, robot.getKbUid());
if (contentList.isEmpty()) {
// 直接返回未找到相关问题答案
String answer = RobotConsts.ROBOT_UNMATCHED;
processAnswerMessage(answer, robot, messageProtobufQuery, messageProtobufReply, emitter);
return;
}
String context = String.join("\n", contentList);
// TODO: 根据配置,拉取历史聊天记录
// String history = "";
prompt = buildKbPrompt(robot.getLlm().getPrompt(), query, context);
} else {
prompt = robot.getLlm().getPrompt();
}
// TODO: 返回消息中携带消息搜索结果(来源依据)
//
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage(prompt));
messages.add(new UserMessage(query));
log.info("BaseSpringAIService sendSseMemberMessage messages {}", messages);
//
Prompt aiPrompt = new Prompt(messages);
processPromptSSE(aiPrompt, messageProtobufQuery, messageProtobufReply, emitter);
}
private void processLlmResponse(MessageProtobuf messageProtobufQuery, MessageProtobuf messageProtobufReply,
SseEmitter emitter) {
//
String prompt = "";
if (StringUtils.hasText(robot.getKbUid()) && robot.getIsKbEnabled()) {
List<String> contentList = springAIVectorService.get().searchText(query, robot.getKbUid());
if (contentList.isEmpty()) {
// 直接返回未找到相关问题答案
messageProtobufReply.setType(MessageTypeEnum.TEXT);
messageProtobufReply.setContent(RobotConsts.ROBOT_UNMATCHED);
messageProtobufReply.setClient(ClientEnum.SYSTEM);
// 保存消息到数据库
persistMessage(messageProtobufQuery, messageProtobufReply);
String messageJson = messageProtobufReply.toJson();
try {
// 发送SSE事件
emitter.send(SseEmitter.event()
.data(messageJson)
.id(messageProtobufReply.getUid())
.name("message"));
} catch (Exception e) {
log.error("BaseSpringAIService sendSseMemberMessage Error sending SSE event 1", e);
emitter.completeWithError(e);
}
return;
}
String context = String.join("\n", contentList);
// TODO: 根据配置,拉取历史聊天记录
// String history = "";
prompt = buildKbPrompt(robot.getLlm().getPrompt(), query, context);
private void processSearchResponse(String query, RobotProtobuf robot, MessageProtobuf messageProtobufQuery,
MessageProtobuf messageProtobufReply, SseEmitter emitter) {
if (StringUtils.hasText(robot.getKbUid()) && robot.getIsKbEnabled()) {
List<String> contentList = springAIVectorService.get().searchText(query, robot.getKbUid());
if (contentList.isEmpty()) {
// 直接返回未找到相关问题答案
String answer = RobotConsts.ROBOT_UNMATCHED;
processAnswerMessage(answer, robot, messageProtobufQuery, messageProtobufReply, emitter);
return;
} else {
prompt = robot.getLlm().getPrompt();
}
// TODO: 返回消息中携带消息搜索结果(来源依据)
//
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage(prompt));
messages.add(new UserMessage(query));
log.info("BaseSpringAIService sendSseMemberMessage messages {}", messages);
//
Prompt aiPrompt = new Prompt(messages);
processPromptSSE(aiPrompt, messageProtobufQuery, messageProtobufReply, emitter);
// 搜索到内容,返回搜索内容
String answer = String.join("\n", contentList);
processAnswerMessage(answer, robot, messageProtobufQuery, messageProtobufReply, emitter);
}
} else {
// 未设置知识库
// 直接返回未找到相关问题答案
String answer = RobotConsts.ROBOT_UNMATCHED;
processAnswerMessage(answer, robot, messageProtobufQuery, messageProtobufReply, emitter);
}
}
private void processAnswerMessage(String answer, RobotProtobuf robot, MessageProtobuf messageProtobufQuery,
MessageProtobuf messageProtobufReply, SseEmitter emitter) {
messageProtobufReply.setType(MessageTypeEnum.TEXT);
messageProtobufReply.setContent(answer);
messageProtobufReply.setClient(ClientEnum.SYSTEM);
// 保存消息到数据库
persistMessage(messageProtobufQuery, messageProtobufReply);
String messageJson = messageProtobufReply.toJson();
try {
// 发送SSE事件
emitter.send(SseEmitter.event()
.data(messageJson)
.id(messageProtobufReply.getUid())
.name("message"));
} catch (Exception e) {
log.error("BaseSpringAIService sendSseMemberMessage Error sending SSE event 1", e);
emitter.completeWithError(e);
}
}
@Override
public String generateFaqPairsAsync(String chunk) {