diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py
index e2e24ea35..f0823be69 100644
--- a/agent/tools/retrieval.py
+++ b/agent/tools/retrieval.py
@@ -18,12 +18,14 @@ import re
from abc import ABC
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.db import LLMType
+from api.db.services.document_service import DocumentService
+from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from api.utils.api_utils import timeout
from rag.app.tag import label_question
-from rag.prompts.generator import cross_languages, kb_prompt
+from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
class RetrievalParam(ToolParamBase):
@@ -58,6 +60,7 @@ class RetrievalParam(ToolParamBase):
self.use_kg = False
self.cross_languages = []
self.toc_enhance = False
+ self.meta_data_filter={}
def check(self):
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
@@ -117,6 +120,21 @@ class Retrieval(ToolBase, ABC):
vars = self.get_input_elements_from_text(kwargs["query"])
vars = {k:o["value"] for k,o in vars.items()}
query = self.string_format(kwargs["query"], vars)
+
+ doc_ids=[]
+ if self._param.meta_data_filter!={}:
+ metas = DocumentService.get_meta_by_kbs(kb_ids)
+ if self._param.meta_data_filter.get("method") == "auto":
+ chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
+ filters = gen_meta_filter(chat_mdl, metas, query)
+ doc_ids.extend(meta_filter(metas, filters))
+ if not doc_ids:
+ doc_ids = None
+ elif self._param.meta_data_filter.get("method") == "manual":
+ doc_ids.extend(meta_filter(metas, self._param.meta_data_filter["manual"]))
+ if not doc_ids:
+ doc_ids = None
+
if self._param.cross_languages:
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
@@ -131,6 +149,7 @@ class Retrieval(ToolBase, ABC):
self._param.top_n,
self._param.similarity_threshold,
1 - self._param.keywords_similarity_weight,
+ doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(query, kbs),
diff --git a/web/src/pages/agent/form/retrieval-form/next.tsx b/web/src/pages/agent/form/retrieval-form/next.tsx
index 67322b2b5..18ccb7ff2 100644
--- a/web/src/pages/agent/form/retrieval-form/next.tsx
+++ b/web/src/pages/agent/form/retrieval-form/next.tsx
@@ -2,6 +2,10 @@ import { Collapse } from '@/components/collapse';
import { CrossLanguageFormField } from '@/components/cross-language-form-field';
import { FormContainer } from '@/components/form-container';
import { KnowledgeBaseFormField } from '@/components/knowledge-base-item';
+import {
+ MetadataFilter,
+ MetadataFilterSchema,
+} from '@/components/metadata-filter';
import { RAGFlowFormItem } from '@/components/ragflow-form';
import { RerankFormFields } from '@/components/rerank';
import { SimilaritySliderFormField } from '@/components/similarity-slider';
@@ -41,6 +45,7 @@ export const RetrievalPartialSchema = {
cross_languages: z.array(z.string()),
use_kg: z.boolean(),
toc_enhance: z.boolean(),
+ ...MetadataFilterSchema,
};
export const FormSchema = z.object({
@@ -118,6 +123,7 @@ function RetrievalForm({ node }: INextOperatorForm) {
>
+
diff --git a/web/src/pages/agent/form/tool-form/retrieval-form/index.tsx b/web/src/pages/agent/form/tool-form/retrieval-form/index.tsx
index c1031081e..9ac763b59 100644
--- a/web/src/pages/agent/form/tool-form/retrieval-form/index.tsx
+++ b/web/src/pages/agent/form/tool-form/retrieval-form/index.tsx
@@ -2,6 +2,7 @@ import { Collapse } from '@/components/collapse';
import { CrossLanguageFormField } from '@/components/cross-language-form-field';
import { FormContainer } from '@/components/form-container';
import { KnowledgeBaseFormField } from '@/components/knowledge-base-item';
+import { MetadataFilter } from '@/components/metadata-filter';
import { RerankFormFields } from '@/components/rerank';
import { SimilaritySliderFormField } from '@/components/similarity-slider';
import { TOCEnhanceFormField } from '@/components/toc-enhance-form-field';
@@ -51,6 +52,7 @@ const RetrievalForm = () => {
>
+