瀏覽代碼

springai添加可访问es数据库,分词返回文档连接

Taohongrun 5 月之前
父節點
當前提交
11b6d71bdb

+ 13 - 0
pom.xml

@@ -28,6 +28,7 @@
         <mybatis-plus-boot-starter.version>3.5.7</mybatis-plus-boot-starter.version>
         <mybatis-plus-boot-starter.version>3.5.7</mybatis-plus-boot-starter.version>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
         <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
         <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
+        <elasticsearch.version>8.11.3</elasticsearch.version>
     </properties>
     </properties>
     <dependencies>
     <dependencies>
         <dependency>
         <dependency>
@@ -166,6 +167,18 @@
             <groupId>org.springframework.boot</groupId>
             <groupId>org.springframework.boot</groupId>
             <artifactId>spring-boot-starter-data-redis</artifactId>
             <artifactId>spring-boot-starter-data-redis</artifactId>
         </dependency>
         </dependency>
+
+        <dependency>
+            <groupId>co.elastic.clients</groupId>
+            <artifactId>elasticsearch-java</artifactId>
+            <version>${elasticsearch.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>jakarta.json</groupId>
+            <artifactId>jakarta.json-api</artifactId>
+            <version>2.0.1</version>
+        </dependency>
     </dependencies>
     </dependencies>
     <dependencyManagement>
     <dependencyManagement>
         <dependencies>
         <dependencies>

+ 63 - 0
src/main/java/io/github/qifan777/knowledge/ai/agent/documentSearch/DocumentEsSearchAgent.java

@@ -0,0 +1,63 @@
+package io.github.qifan777.knowledge.ai.agent.documentSearch;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import io.github.qifan777.knowledge.ai.agent.AbstractAgent;
+import io.github.qifan777.knowledge.ai.agent.Agent;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+
+import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.model.ChatModel;
+
+import org.springframework.context.annotation.Description;
+
+import java.util.function.Function;
+
+
+/**
+ * <p>
+ *
+ * </p>
+ *
+ * @author taohongrun
+ * @since 2025/3/17
+ */
+@Agent
+@Description("回答用户查询文档相关问题")
+@Slf4j
+@RequiredArgsConstructor
+public class DocumentEsSearchAgent extends AbstractAgent implements Function<DocumentEsSearchAgent.Request, String> {
+
+    private final ChatModel chatModel;
+
+    @Override
+    public String apply(DocumentEsSearchAgent.Request request) {
+        log.info("用户想要查询的文档名称: {}", request.fileName);
+        log.info("用户的角色类型: {}", request.roleName);
+        return ChatClient.create(chatModel)
+                .prompt()
+                .system(s -> s.text("用户查询文档相关问题,用户角色权限:{role},文档名称:{fileName},问题:{query}")
+                        .param("role", request.roleName)
+                        .param("fileName", request.fileName)
+                        .param("query", request.query()))
+                .functions(getFunctions(SearchDocumentFunction.class))
+                .user(userSpec -> userSpec
+                        .param("roleName", request.roleName)
+                        .param("fileName", request.fileName)// 传递参数
+                        .text(request.query()))
+                .call()
+                .content();
+    }
+
+    public record Request(@JsonProperty(required = true)
+                          @JsonPropertyDescription("用户想要查询的文档名称")
+                          String fileName
+            ,
+                          @JsonProperty(required = true)
+                          @JsonPropertyDescription("用户的角色类型,如老师,学生")
+                          String roleName
+            ,
+                          @JsonProperty(required = true) @JsonPropertyDescription(value = "用户原始的提问") String query) {
+    }
+}

+ 80 - 0
src/main/java/io/github/qifan777/knowledge/ai/agent/documentSearch/SearchDocumentFunction.java

@@ -0,0 +1,80 @@
+package io.github.qifan777.knowledge.ai.agent.documentSearch;
+
+import co.elastic.clients.elasticsearch.ElasticsearchClient;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import io.github.qifan777.knowledge.domain.po.SearchDocument;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.context.annotation.Description;
+import org.springframework.stereotype.Component;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * <p>
+ *
+ * </p>
+ *
+ * @author taohongrun
+ * @since 2025/3/17
+ */
+@Component
+@Description("返回用户想要查询的文档连接")
+@Slf4j
+@RequiredArgsConstructor
+public class SearchDocumentFunction implements Function<SearchDocumentFunction.Request, String> {
+    private final ElasticsearchClient client;
+
+    @Override
+    public String apply(SearchDocumentFunction.Request request) {
+        //获取的参数
+        log.info("用户想要查询的文档名称: {}", request.fileName);
+        log.info("用户的角色类型: {}", request.roleName);
+
+        try {
+            List<String> urlList = new ArrayList<>();
+
+            urlList = client.search(s -> s.index("documents_index")
+                                    .query(q -> q
+                                            .bool(b -> b
+                                                    .must(m -> m
+                                                            .match(t -> t
+                                                                    .field("title")
+                                                                    .query(request.fileName)))
+                                                    //"permissions"为数组,判断是否包含roleName
+
+                                                    .filter(m -> m
+                                                            .term(t -> t
+                                                                    .field("permissions")
+                                                                    .value(request.roleName)))
+                                            ))
+                            , SearchDocument.class)
+                    .hits()
+                    .hits()
+                    .stream()
+                    .map(h -> {
+                        assert h.source() != null;
+                        return h.source().getUrl();
+                    })
+                    .collect(Collectors.toList());
+            log.info("查询结果: {}", urlList);
+            return String.join("\n", urlList);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public record Request(@JsonProperty(required = true)
+                          @JsonPropertyDescription("用户想要查询的文档名称")
+                          String fileName
+            ,
+                          @JsonProperty(required = true)
+                          @JsonPropertyDescription("用户的角色类型,如老师,学生")
+                          String roleName) {
+    }
+}

+ 20 - 8
src/main/java/io/github/qifan777/knowledge/ai/message/AiMessageController.java

@@ -85,7 +85,7 @@ public class AiMessageController {
         AiMessageWrapper aiMessageWrapper = objectMapper.readValue(input, AiMessageWrapper.class);
         AiMessageWrapper aiMessageWrapper = objectMapper.readValue(input, AiMessageWrapper.class);
         String[] functionBeanNames = new String[0];
         String[] functionBeanNames = new String[0];
         // 如果启用Agent则获取Agent的bean
         // 如果启用Agent则获取Agent的bean
-        if (aiMessageWrapper.getParams().getEnableAgent()) {
+        if (true) {
             // 获取带有Agent注解的bean
             // 获取带有Agent注解的bean
             Map<String, Object> beansWithAnnotation = applicationContext.getBeansWithAnnotation(Agent.class);
             Map<String, Object> beansWithAnnotation = applicationContext.getBeansWithAnnotation(Agent.class);
             functionBeanNames = new String[beansWithAnnotation.size()];
             functionBeanNames = new String[beansWithAnnotation.size()];
@@ -156,15 +156,27 @@ public class AiMessageController {
 
 
     @SneakyThrows
     @SneakyThrows
     public void useFile(ChatClient.PromptSystemSpec spec, MultipartFile file) {
     public void useFile(ChatClient.PromptSystemSpec spec, MultipartFile file) {
-        if (file == null) return;
+        String username = Db.lambdaQuery(User.class)
+                .select().eq(User::getId, UserContext.getThreadUserId())
+                .one().getNickname();
+        if (file == null) {
+            Message message = new PromptTemplate("""
+               你是一个校园管理助手,请遵守以下规则:
+        1. 当用户查询我的成绩时,只能查询学生名为{username}的成绩,不能查询其他人的成绩
+        2. 其他信息查询需保持客观中立
+        3. 查询文档文件时需要严格记住当前用户的角色权限为主任
+        """).createMessage(Map.of( "username", username));
+            spec.text(message.getText());
+            return;
+        }
         String content = new TikaDocumentReader(new InputStreamResource(file.getInputStream())).get().get(0).getText();
         String content = new TikaDocumentReader(new InputStreamResource(file.getInputStream())).get().get(0).getText();
         Message message = new PromptTemplate("""
         Message message = new PromptTemplate("""
-                已下内容是额外的知识,在你回答问题时可以参考下面的内容
-                ---------------------
-                {context}
-                ---------------------
-                """)
-                .createMessage(Map.of("context", content));
+               你是一个校园管理助手,请遵守以下规则:
+        1. 当用户查询我的成绩时,只能查询学生名为{username}的成绩,不能查询其他人的成绩
+        2. 其他信息查询需保持客观中立
+        {context}  // 保留原有文件内容逻辑
+        """)
+                .createMessage(Map.of("context", content, "username", username));
         spec.text(message.getText());
         spec.text(message.getText());
     }
     }
 
 

+ 30 - 0
src/main/java/io/github/qifan777/knowledge/config/ElasticSearchConf.java

@@ -0,0 +1,30 @@
+package io.github.qifan777.knowledge.config;
+
+import co.elastic.clients.elasticsearch.ElasticsearchClient;
+import co.elastic.clients.json.jackson.JacksonJsonpMapper;
+import co.elastic.clients.transport.rest_client.RestClientTransport;
+import org.apache.http.HttpHost;
+import org.elasticsearch.client.RestClient;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+/**
+ * <p>
+ *
+ * </p>
+ *
+ * @author taohongrun
+ * @since 2025/3/17
+ */
+@Configuration
+public class ElasticSearchConf {
+    @Bean
+    public ElasticsearchClient elasticSearchClient() {
+        RestClient restClient=RestClient.builder(new HttpHost("58.87.69.234", 9200))
+                .build();
+
+        RestClientTransport restClientTransport = new RestClientTransport(restClient, new JacksonJsonpMapper());
+
+        return new ElasticsearchClient(restClientTransport);
+    }
+}

+ 26 - 0
src/main/java/io/github/qifan777/knowledge/domain/po/SearchDocument.java

@@ -0,0 +1,26 @@
+package io.github.qifan777.knowledge.domain.po;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.babyfish.jimmer.sql.ast.PropExpression;
+
+import java.util.List;
+
+/**
+ * <p>
+ *
+ * </p>
+ *
+ * @author taohongrun
+ * @since 2025/3/17
+ */
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class SearchDocument {
+    String title;
+    String url;
+    List<String> permissions; // 修改: 将 String 类型改为 List<String> 类型
+    String file_type;
+}