|
@@ -1,13 +1,16 @@
|
|
|
package io.github.qifan777.knowledge.ai.message;
|
|
|
|
|
|
+import com.baomidou.mybatisplus.extension.toolkit.Db;
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
import io.github.qifan777.knowledge.ai.agent.Agent;
|
|
|
import io.github.qifan777.knowledge.ai.message.dto.AiMessageInput;
|
|
|
import io.github.qifan777.knowledge.ai.message.dto.AiMessageWrapper;
|
|
|
import io.github.qifan777.knowledge.context.UserContext;
|
|
|
+import io.github.qifan777.knowledge.domain.po.User;
|
|
|
import lombok.AllArgsConstructor;
|
|
|
import lombok.SneakyThrows;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.babyfish.jimmer.sql.ast.PropExpression;
|
|
|
import org.springframework.ai.chat.client.ChatClient;
|
|
|
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
|
|
|
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
|
|
@@ -79,7 +82,6 @@ public class AiMessageController {
|
|
|
@SneakyThrows
|
|
|
@PostMapping(value = "chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
|
|
public Flux<ServerSentEvent<String>> chat(@RequestPart String input, @RequestPart(required = false) MultipartFile file) {
|
|
|
- System.out.println(UserContext.getThreadUserId());
|
|
|
AiMessageWrapper aiMessageWrapper = objectMapper.readValue(input, AiMessageWrapper.class);
|
|
|
String[] functionBeanNames = new String[0];
|
|
|
// 如果启用Agent则获取Agent的bean
|
|
@@ -89,10 +91,14 @@ public class AiMessageController {
|
|
|
functionBeanNames = new String[beansWithAnnotation.size()];
|
|
|
functionBeanNames = beansWithAnnotation.keySet().toArray(functionBeanNames);
|
|
|
}
|
|
|
+ String userId = UserContext.getThreadUserId();
|
|
|
return ChatClient.create(chatModel).prompt()
|
|
|
// 启用文件问答
|
|
|
.system(promptSystemSpec -> useFile(promptSystemSpec, file))
|
|
|
- .user(promptUserSpec -> toPrompt(promptUserSpec, aiMessageWrapper.getMessage()))
|
|
|
+ .user(promptUserSpec -> {
|
|
|
+ toPrompt(promptUserSpec, aiMessageWrapper.getMessage());
|
|
|
+ promptUserSpec.param("userName", Db.lambdaQuery(User.class).select().eq(User::getId,userId).one().getNickname());
|
|
|
+ })
|
|
|
// agent列表
|
|
|
.functions(functionBeanNames)
|
|
|
.advisors(advisorSpec -> {
|