diff --git a/modules/ai/src/main/java/com/bytedesk/ai/robot/RobotEventListener.java b/modules/ai/src/main/java/com/bytedesk/ai/robot/RobotEventListener.java index 6fbc6b5c29..171ad006ce 100644 --- a/modules/ai/src/main/java/com/bytedesk/ai/robot/RobotEventListener.java +++ b/modules/ai/src/main/java/com/bytedesk/ai/robot/RobotEventListener.java @@ -2,7 +2,7 @@ * @Author: jackning 270580156@qq.com * @Date: 2024-06-12 07:17:13 * @LastEditors: jackning 270580156@qq.com - * @LastEditTime: 2025-03-11 17:03:53 + * @LastEditTime: 2025-03-11 17:29:15 * @Description: bytedesk.com https://github.com/Bytedesk/bytedesk * Please be aware of the BSL license restrictions before installing Bytedesk IM – * selling, reselling, or hosting Bytedesk IM as a service is a breach of the terms and automatically terminates your rights under the license. @@ -103,7 +103,8 @@ public class RobotEventListener { return; } String threadTopic = threadProtobuf.getTopic(); - if (threadProtobuf.getType().equals(ThreadTypeEnum.LLM) || threadProtobuf.getType().equals(ThreadTypeEnum.ROBOT)) { + if (threadProtobuf.getType().equals(ThreadTypeEnum.LLM) || + threadProtobuf.getType().equals(ThreadTypeEnum.ROBOT)) { log.info("robot robot threadTopic {}, thread.type {}", threadTopic, threadProtobuf.getType()); processRobotThreadMessage(query, threadTopic, threadProtobuf, messageProtobuf); } diff --git a/modules/ai/src/main/java/com/bytedesk/ai/robot/RobotService.java b/modules/ai/src/main/java/com/bytedesk/ai/robot/RobotService.java new file mode 100644 index 0000000000..477813845f --- /dev/null +++ b/modules/ai/src/main/java/com/bytedesk/ai/robot/RobotService.java @@ -0,0 +1,197 @@ +/* + * @Author: jackning 270580156@qq.com + * @Date: 2025-03-11 17:29:51 + * @LastEditors: jackning 270580156@qq.com + * @LastEditTime: 2025-03-11 17:47:59 + * @Description: bytedesk.com https://github.com/Bytedesk/bytedesk + * Please be aware of the BSL license restrictions before installing Bytedesk IM – + * selling, reselling, or hosting Bytedesk IM as a service is a breach of the terms and automatically terminates your rights under the license. + * Business Source License 1.1: https://github.com/Bytedesk/bytedesk/blob/main/LICENSE + * contact: 270580156@qq.com + * + * Copyright (c) 2025 by bytedesk.com, All Rights Reserved. + */ +package com.bytedesk.ai.robot; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +// import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.stereotype.Service; +import org.springframework.util.SerializationUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import com.alibaba.fastjson2.JSON; +import com.bytedesk.ai.robot_message.RobotMessageUtils; +import com.bytedesk.ai.springai.ollama.SpringAIOllamaService; +import com.bytedesk.ai.springai.spring.SpringAIVectorService; +import com.bytedesk.core.message.IMessageSendService; +import com.bytedesk.core.message.MessageProtobuf; +import com.bytedesk.core.message.MessageTypeEnum; +import com.bytedesk.core.rbac.user.UserProtobuf; +import com.bytedesk.core.rbac.user.UserTypeEnum; +import com.bytedesk.core.thread.ThreadEntity; +import com.bytedesk.core.thread.ThreadProtobuf; +import com.bytedesk.core.thread.ThreadRestService; +import com.bytedesk.core.uid.UidUtils; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +@Service +@RequiredArgsConstructor +public class RobotService { + + private final RobotRestService robotRestService; + + private final ThreadRestService threadRestService; + + protected final IMessageSendService messageSendService; + + private final UidUtils uidUtils; + + // private final Optional bytedeskOllamaChatModel; + + protected final Optional springAIVectorService; + + private final SpringAIOllamaService springAIOllamaService; + + protected void processPromptSSE(String messageJson, SseEmitter emitter) { + // + MessageProtobuf messageProtobuf = JSON.parseObject(messageJson, MessageProtobuf.class); + MessageTypeEnum messageType = messageProtobuf.getType(); + if (messageType.equals(MessageTypeEnum.STREAM)) { + return; + } + String query = messageProtobuf.getContent(); + log.info("robot processMessage {}", query); + ThreadProtobuf threadProtobuf = messageProtobuf.getThread(); + if (threadProtobuf == null) { + throw new RuntimeException("thread is null"); + } + // 暂时仅支持文字消息类型,其他消息类型,大模型暂不处理。 + if (!messageType.equals(MessageTypeEnum.TEXT)) { + return; + } + String threadTopic = threadProtobuf.getTopic(); + ThreadEntity thread = threadRestService.findFirstByTopic(threadTopic) + .orElseThrow(() -> new RuntimeException("thread with topic " + threadTopic + + " not found")); + UserProtobuf agent = JSON.parseObject(thread.getAgent(), UserProtobuf.class); + if (agent.getType().equals(UserTypeEnum.ROBOT.name())) { + log.info("robot thread reply"); + RobotEntity robot = robotRestService.findByUid(agent.getUid()) + .orElseThrow(() -> new RuntimeException("robot " + agent.getUid() + " not found")); + // + MessageProtobuf message = RobotMessageUtils.createRobotMessage(thread, threadProtobuf, robot, + messageProtobuf); + // + MessageProtobuf clonedMessage = SerializationUtils.clone(message); + clonedMessage.setUid(uidUtils.getUid()); + clonedMessage.setType(MessageTypeEnum.PROCESSING); + messageSendService.sendProtobufMessage(clonedMessage); + // + String prompt = ""; + if (StringUtils.hasText(robot.getKbUid()) && robot.isKbEnabled()) { + List contentList = springAIVectorService.get().searchText(query, robot.getKbUid()); + String context = String.join("\n", contentList); + prompt = springAIOllamaService.buildKbPrompt(robot.getLlm().getPrompt(), query, context); + } else { + prompt = robot.getLlm().getPrompt(); + } + // + List messages = new ArrayList<>(); + messages.add(new SystemMessage(prompt)); + messages.add(new UserMessage(query)); + // + // Prompt aiPrompt = new Prompt(messages); + // + // springAIOllamaService.sendSseMessage(query, robot, message); + // + // ollamaProcess(robot, aiPrompt, threadProtobuf, message, emitter); + } + } + + // private void ollamaProcess(RobotEntity robot, Prompt aiPrompt, ThreadProtobuf thread, MessageProtobuf message, SseEmitter emitter) { + // // 你的处理逻辑 + // bytedeskOllamaChatModel.ifPresentOrElse( + // model -> { + + // model.stream(aiPrompt).subscribe( + // response -> { + // try { + // if (response != null) { + // List generations = response.getResults(); + // for (Generation generation : generations) { + // AssistantMessage assistantMessage = generation.getOutput(); + // String textContent = assistantMessage.getText(); + // // + // message.setContent(textContent); + // message.setType(MessageTypeEnum.STREAM); + // // 发送SSE事件 + // emitter.send(SseEmitter.event() + // .data(JSON.toJSONString(message)) + // .id(message.getUid()) + // .name("message")); + // } + // } + // } catch (Exception e) { + // log.error("Error sending SSE event", e); + + // emitter.completeWithError(e); + // } + // }, + // error -> { + // log.error("Ollama API SSE error: ", error); + // try { + // message.setType(MessageTypeEnum.ERROR); + // message.setContent("服务暂时不可用,请稍后重试"); + // // + // emitter.send(SseEmitter.event() + // .data(JSON.toJSONString(message)) + // .id(message.getUid()) + // .name("error")); + // emitter.complete(); + // } catch (Exception e) { + // emitter.completeWithError(e); + // } + // }, + // () -> { + // try { + // // 发送流结束标记 + // message.setType(MessageTypeEnum.STREAM_END); + // message.setContent(""); // 或者可以是任何结束标记 + // emitter.send(SseEmitter.event() + // .data(JSON.toJSONString(message)) + // .id(message.getUid()) + // .name("end")); + // emitter.complete(); + // } catch (Exception e) { + // log.error("Error completing SSE", e); + // } + // }); + // }, + // () -> { + // try { + // // 发送流结束标记 + // message.setType(MessageTypeEnum.STREAM_END); + // message.setContent("Ollama service is not available"); // 或者可以是任何结束标记 + // emitter.send(SseEmitter.event() + // .data(JSON.toJSONString(message)) + // .id(message.getUid()) + // .name("ollama_error")); + // emitter.complete(); + // } catch (Exception e) { + // emitter.completeWithError(e); + // } + // }); + // } + + +} diff --git a/modules/ai/src/main/java/com/bytedesk/ai/springai/base/BaseSpringAIService.java b/modules/ai/src/main/java/com/bytedesk/ai/springai/base/BaseSpringAIService.java index eaae970e43..9f5e410f4d 100644 --- a/modules/ai/src/main/java/com/bytedesk/ai/springai/base/BaseSpringAIService.java +++ b/modules/ai/src/main/java/com/bytedesk/ai/springai/base/BaseSpringAIService.java @@ -9,15 +9,26 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.util.Assert; +import org.springframework.util.SerializationUtils; import org.springframework.util.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import com.alibaba.fastjson2.JSON; import com.bytedesk.ai.robot.RobotConsts; import com.bytedesk.ai.robot.RobotEntity; +import com.bytedesk.ai.robot.RobotRestService; +import com.bytedesk.ai.robot_message.RobotMessageUtils; import com.bytedesk.ai.springai.spring.SpringAIService; import com.bytedesk.ai.springai.spring.SpringAIVectorService; import com.bytedesk.core.message.IMessageSendService; import com.bytedesk.core.message.MessageProtobuf; +import com.bytedesk.core.message.MessageTypeEnum; +import com.bytedesk.core.rbac.user.UserProtobuf; +import com.bytedesk.core.rbac.user.UserTypeEnum; +import com.bytedesk.core.thread.ThreadEntity; +import com.bytedesk.core.thread.ThreadProtobuf; +import com.bytedesk.core.thread.ThreadRestService; +import com.bytedesk.core.uid.UidUtils; import lombok.extern.slf4j.Slf4j; @@ -26,11 +37,21 @@ public abstract class BaseSpringAIService implements SpringAIService { protected final Optional springAIVectorService; protected final IMessageSendService messageSendService; + protected final UidUtils uidUtils; + protected final RobotRestService robotRestService; + protected final ThreadRestService threadRestService; protected BaseSpringAIService(Optional springAIVectorService, - IMessageSendService messageSendService) { + IMessageSendService messageSendService, + UidUtils uidUtils, + RobotRestService robotRestService, + ThreadRestService threadRestService + ) { this.springAIVectorService = springAIVectorService; this.messageSendService = messageSendService; + this.uidUtils = uidUtils; + this.robotRestService = robotRestService; + this.threadRestService = threadRestService; } @Override @@ -59,9 +80,59 @@ public abstract class BaseSpringAIService implements SpringAIService { } @Override - public void sendKbaseSseMessage(String message, SseEmitter emitter) { - Assert.hasText(message, "Message must not be empty"); + public void sendSseMessage(String messageJson, SseEmitter emitter) { + Assert.hasText(messageJson, "Message must not be empty"); Assert.notNull(emitter, "SseEmitter must not be null"); + // + MessageProtobuf messageProtobuf = JSON.parseObject(messageJson, MessageProtobuf.class); + MessageTypeEnum messageType = messageProtobuf.getType(); + if (messageType.equals(MessageTypeEnum.STREAM)) { + return; + } + String query = messageProtobuf.getContent(); + log.info("robot processMessage {}", query); + ThreadProtobuf threadProtobuf = messageProtobuf.getThread(); + if (threadProtobuf == null) { + throw new RuntimeException("thread is null"); + } + // 暂时仅支持文字消息类型,其他消息类型,大模型暂不处理。 + if (!messageType.equals(MessageTypeEnum.TEXT)) { + return; + } + String threadTopic = threadProtobuf.getTopic(); + ThreadEntity thread = threadRestService.findFirstByTopic(threadTopic) + .orElseThrow(() -> new RuntimeException("thread with topic " + threadTopic + + " not found")); + UserProtobuf agent = JSON.parseObject(thread.getAgent(), UserProtobuf.class); + if (agent.getType().equals(UserTypeEnum.ROBOT.name())) { + log.info("robot thread reply"); + RobotEntity robot = robotRestService.findByUid(agent.getUid()) + .orElseThrow(() -> new RuntimeException("robot " + agent.getUid() + " not found")); + // + MessageProtobuf message = RobotMessageUtils.createRobotMessage(thread, threadProtobuf, robot, + messageProtobuf); + // + MessageProtobuf clonedMessage = SerializationUtils.clone(message); + clonedMessage.setUid(uidUtils.getUid()); + clonedMessage.setType(MessageTypeEnum.PROCESSING); + messageSendService.sendProtobufMessage(clonedMessage); + // + String prompt = ""; + if (StringUtils.hasText(robot.getKbUid()) && robot.isKbEnabled()) { + List contentList = springAIVectorService.get().searchText(query, robot.getKbUid()); + String context = String.join("\n", contentList); + prompt = buildKbPrompt(robot.getLlm().getPrompt(), query, context); + } else { + prompt = robot.getLlm().getPrompt(); + } + // + List messages = new ArrayList<>(); + messages.add(new SystemMessage(prompt)); + messages.add(new UserMessage(query)); + // + Prompt aiPrompt = new Prompt(messages); + processPromptSSE(robot, aiPrompt, threadProtobuf, message, emitter); + } } @Override @@ -105,7 +176,7 @@ public abstract class BaseSpringAIService implements SpringAIService { } } - protected String buildKbPrompt(String systemPrompt, String query, String context) { + public String buildKbPrompt(String systemPrompt, String query, String context) { return systemPrompt + "\n" + "用户查询: " + query + "\n" + "历史聊天记录: " + "\n" + diff --git a/modules/ai/src/main/java/com/bytedesk/ai/springai/spring/SpringAIService.java b/modules/ai/src/main/java/com/bytedesk/ai/springai/spring/SpringAIService.java index 37bcae0ca9..3292a87d11 100644 --- a/modules/ai/src/main/java/com/bytedesk/ai/springai/spring/SpringAIService.java +++ b/modules/ai/src/main/java/com/bytedesk/ai/springai/spring/SpringAIService.java @@ -2,7 +2,7 @@ * @Author: jackning 270580156@qq.com * @Date: 2025-02-26 14:48:03 * @LastEditors: jackning 270580156@qq.com - * @LastEditTime: 2025-03-11 16:55:53 + * @LastEditTime: 2025-03-11 17:28:28 * @Description: bytedesk.com https://github.com/Bytedesk/bytedesk * Please be aware of the BSL license restrictions before installing Bytedesk IM – * selling, reselling, or hosting Bytedesk IM as a service is a breach of the terms and automatically terminates your rights under the license. @@ -37,7 +37,7 @@ public interface SpringAIService { * @param message 消息 * @param emitter SseEmitter */ - void sendKbaseSseMessage(String message, SseEmitter emitter); + void sendSseMessage(String message, SseEmitter emitter); /** * 异步生成FAQ对 diff --git a/modules/service/src/main/java/com/bytedesk/service/visitor/VisitorRestControllerAnonymous.java b/modules/service/src/main/java/com/bytedesk/service/visitor/VisitorRestControllerAnonymous.java index daa8152db1..1937c7ed16 100644 --- a/modules/service/src/main/java/com/bytedesk/service/visitor/VisitorRestControllerAnonymous.java +++ b/modules/service/src/main/java/com/bytedesk/service/visitor/VisitorRestControllerAnonymous.java @@ -2,7 +2,7 @@ * @Author: jackning 270580156@qq.com * @Date: 2024-01-29 16:21:24 * @LastEditors: jackning 270580156@qq.com - * @LastEditTime: 2025-03-11 16:12:33 + * @LastEditTime: 2025-03-11 17:25:00 * @Description: bytedesk.com https://github.com/Bytedesk/bytedesk * Please be aware of the BSL license restrictions before installing Bytedesk IM – * selling, reselling, or hosting Bytedesk IM as a service is a breach of the terms and automatically terminates your rights under the license. @@ -15,13 +15,19 @@ package com.bytedesk.service.visitor; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import com.bytedesk.core.annotation.ApiRateLimiter; import com.bytedesk.core.config.BytedeskEventPublisher; @@ -59,6 +65,8 @@ public class VisitorRestControllerAnonymous { private final BytedeskEventPublisher bytedeskEventPublisher; + private final ExecutorService executorService = Executors.newCachedThreadPool(); + @VisitorAnnotation(title = "visitor", action = "init", description = "init visitor") @ApiRateLimiter(value = 10.0, timeout = 1) @PostMapping("/init") @@ -142,5 +150,39 @@ public class VisitorRestControllerAnonymous { return ResponseEntity.ok(JsonResult.success(json)); } + @VisitorAnnotation(title = "visitor", action = "sendSseMessage", description = "sendSseMessage") + @GetMapping(value = "/message/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter sendSseMessage(@RequestParam(value = "message") String message) { + + SseEmitter emitter = new SseEmitter(180_000L); // 3分钟超时 + + executorService.execute(() -> { + try { + // springAIOllamaService.processPromptSSE(message, emitter); + } catch (Exception e) { + log.error("Error processing SSE request", e); + emitter.completeWithError(e); + } + }); + + // 添加超时和完成时的回调 + emitter.onTimeout(() -> { + log.warn("SSE connection timed out"); + emitter.complete(); + }); + + emitter.onCompletion(() -> { + log.info("SSE connection completed"); + }); + + return emitter; + } + + // 在 Bean 销毁时关闭线程池 + public void destroy() { + if (executorService != null && !executorService.isShutdown()) { + executorService.shutdown(); + } + } }