diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 52e295fd0..a0c990a81 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -136,6 +136,16 @@ class Retrieval(ToolBase, ABC): doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None + elif self._param.meta_data_filter.get("method") == "semi_auto": + selected_keys = self._param.meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT) + filters: dict = gen_meta_filter(chat_mdl, filtered_metas, query) + doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not doc_ids: + doc_ids = None elif self._param.meta_data_filter.get("method") == "manual": filters = self._param.meta_data_filter["manual"] for flt in filters: diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index d96de64d0..37cd0c7a1 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -327,10 +327,44 @@ async def retrieval_test(): local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None + elif meta_data_filter.get("method") == "semi_auto": + selected_keys = meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None elif meta_data_filter.get("method") == "manual": local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) if meta_data_filter["manual"] and not local_doc_ids: local_doc_ids = ["-999"] + else: + meta_data_filter = req.get("meta_data_filter") + if meta_data_filter: + metas = DocumentService.get_meta_by_kbs(kb_ids) + if meta_data_filter.get("method") == "auto": + chat_mdl = LLMBundle(user_id, LLMType.CHAT) + filters: dict = gen_meta_filter(chat_mdl, metas, question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None + elif meta_data_filter.get("method") == "semi_auto": + selected_keys = meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + chat_mdl = LLMBundle(user_id, LLMType.CHAT) + filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None + elif meta_data_filter.get("method") == "manual": + local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) + if meta_data_filter["manual"] and not local_doc_ids: + local_doc_ids = ["-999"] tenants = UserTenantService.query(user_id=user_id) for kb_id in kb_ids: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index fe4723984..27760f1a8 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -984,10 +984,45 @@ async def retrieval_test_embedded(): local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None + elif meta_data_filter.get("method") == "semi_auto": + selected_keys = meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + filters: dict = gen_meta_filter(chat_mdl, filtered_metas, _question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None elif meta_data_filter.get("method") == "manual": local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) if meta_data_filter["manual"] and not local_doc_ids: local_doc_ids = ["-999"] + else: + meta_data_filter = req.get("meta_data_filter") + if meta_data_filter: + metas = DocumentService.get_meta_by_kbs(kb_ids) + if meta_data_filter.get("method") == "auto": + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) + filters: dict = gen_meta_filter(chat_mdl, metas, question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None + elif meta_data_filter.get("method") == "semi_auto": + selected_keys = meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) + filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None + elif meta_data_filter.get("method") == "manual": + local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) + if meta_data_filter["manual"] and not local_doc_ids: + local_doc_ids = ["-999"] + tenants = UserTenantService.query(user_id=tenant_id) for kb_id in kb_ids: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 68e470eb8..88f61f190 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -425,6 +425,15 @@ async def async_chat(dialog, messages, stream=True, **kwargs): attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not attachments: attachments = None + elif dialog.meta_data_filter.get("method") == "semi_auto": + selected_keys = dialog.meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + filters: dict = gen_meta_filter(chat_mdl, filtered_metas, questions[-1]) + attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not attachments: + attachments = None elif dialog.meta_data_filter.get("method") == "manual": conds = dialog.meta_data_filter["manual"] attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and"))) @@ -834,6 +843,15 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None + elif meta_data_filter.get("method") == "semi_auto": + selected_keys = meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not doc_ids: + doc_ids = None elif meta_data_filter.get("method") == "manual": doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) if meta_data_filter["manual"] and not doc_ids: @@ -909,6 +927,15 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None + elif meta_data_filter.get("method") == "semi_auto": + selected_keys = meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not doc_ids: + doc_ids = None elif meta_data_filter.get("method") == "manual": doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) if meta_data_filter["manual"] and not doc_ids: diff --git a/web/src/components/metadata-filter/index.tsx b/web/src/components/metadata-filter/index.tsx index 48388e4c2..af72f5865 100644 --- a/web/src/components/metadata-filter/index.tsx +++ b/web/src/components/metadata-filter/index.tsx @@ -5,6 +5,7 @@ import { z } from 'zod'; import { SelectWithSearch } from '../originui/select-with-search'; import { RAGFlowFormItem } from '../ragflow-form'; import { MetadataFilterConditions } from './metadata-filter-conditions'; +import { MetadataSemiAutoFields } from './metadata-semi-auto-fields'; type MetadataFilterProps = { prefix?: string; @@ -25,6 +26,9 @@ export const MetadataFilterSchema = { }), ) .optional(), + semi_auto: z + .array(z.string()) // 修改为字符串数组 + .optional(), }) .optional(), }; @@ -76,6 +80,12 @@ export function MetadataFilter({ canReference={canReference} > )} + {hasKnowledge && metadata === DatasetMetadata.SemiAutomatic && ( + + )} ); } diff --git a/web/src/components/metadata-filter/metadata-semi-auto-fields.tsx b/web/src/components/metadata-filter/metadata-semi-auto-fields.tsx new file mode 100644 index 000000000..2948700ba --- /dev/null +++ b/web/src/components/metadata-filter/metadata-semi-auto-fields.tsx @@ -0,0 +1,100 @@ +import { Button } from '@/components/ui/button'; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu'; +import { + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request'; +import { Plus, X } from 'lucide-react'; +import { useCallback } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +export function MetadataSemiAutoFields({ + kbIds, + prefix = '', +}: { + kbIds: string[]; + prefix?: string; +}) { + const { t } = useTranslation(); + const form = useFormContext(); + const name = prefix + 'meta_data_filter.semi_auto'; + const metadata = useFetchKnowledgeMetadata(kbIds); + + const { fields, remove, append } = useFieldArray({ + name, + control: form.control, + }); + + const add = useCallback( + (key: string) => () => { + append(key); // 直接添加字符串而不是对象 + }, + [append], + ); + + return ( +
+
+ {t('chat.metadataKeys')} + + + + + + {Object.keys(metadata.data).map((key, idx) => { + return ( + + {key} + + ); + })} + + +
+
+ {fields.map((field, index) => { + // 修改字段名称以直接引用数组元素 + const typeField = `${name}.${index}`; + return ( +
+
+ ( + + + + + + + )} + /> +
+ +
+ ); + })} +
+
+ ); +} diff --git a/web/src/constants/chat.ts b/web/src/constants/chat.ts index 02d23b652..f102d9e79 100644 --- a/web/src/constants/chat.ts +++ b/web/src/constants/chat.ts @@ -36,5 +36,6 @@ export const EmptyConversationId = 'empty'; export enum DatasetMetadata { Disabled = 'disabled', Automatic = 'auto', + SemiAutomatic = 'semi_auto', Manual = 'manual', } diff --git a/web/src/interfaces/request/knowledge.ts b/web/src/interfaces/request/knowledge.ts index de1b00b43..ea60888de 100644 --- a/web/src/interfaces/request/knowledge.ts +++ b/web/src/interfaces/request/knowledge.ts @@ -7,6 +7,16 @@ export interface ITestRetrievalRequestBody { use_kg?: boolean; highlight?: boolean; kb_id?: string[]; + meta_data_filter?: { + logic?: string; + method?: string; + manual?: Array<{ + key: string; + op: string; + value: string; + }>; + semi_auto?: string[]; + }; } export interface IFetchKnowledgeListRequestBody { diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 010bf1efa..6d1049eda 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -737,11 +737,13 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s metadataTip: 'Metadata filtering is the process of using metadata attributes (such as tags, categories, or access permissions) to refine and control the retrieval of relevant information within a system.', conditions: 'Conditions', + metadataKeys: 'Filterable items', addCondition: 'Add condition', meta: { disabled: 'Disabled', auto: 'Automatic', manual: 'Manual', + semi_auto: 'Semi-automatic', }, cancel: 'Cancel', chatSetting: 'Chat setting', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 02ac586ef..c06ec2886 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -673,11 +673,13 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 metadataTip: '元数据过滤是使用元数据属性(例如标签、类别或访问权限)来优化和控制系统内相关信息检索的过程。', conditions: '条件', + metadataKeys: '可选过滤项', addCondition: '增加条件', meta: { disabled: '禁用', auto: '自动', manual: '手动', + semi_auto: '半自动', }, cancel: '取消', chatSetting: '聊天设置', diff --git a/web/src/pages/dataset/testing/testing-form.tsx b/web/src/pages/dataset/testing/testing-form.tsx index f6dc195b0..760dbdc01 100644 --- a/web/src/pages/dataset/testing/testing-form.tsx +++ b/web/src/pages/dataset/testing/testing-form.tsx @@ -7,14 +7,18 @@ import { z } from 'zod'; import { CrossLanguageFormField } from '@/components/cross-language-form-field'; import { FormContainer } from '@/components/form-container'; import { - initialTopKValue, + MetadataFilter, + MetadataFilterSchema, +} from '@/components/metadata-filter'; +import { RerankFormFields, + initialTopKValue, topKSchema, } from '@/components/rerank'; import { + SimilaritySliderFormField, initialSimilarityThresholdValue, initialVectorSimilarityWeightValue, - SimilaritySliderFormField, similarityThresholdSchema, vectorSimilarityWeightSchema, } from '@/components/similarity-slider'; @@ -33,6 +37,7 @@ import { trim } from 'lodash'; import { Send } from 'lucide-react'; import { useEffect } from 'react'; import { useTranslation } from 'react-i18next'; +import { useParams } from 'umi'; type TestingFormProps = Pick< ReturnType, @@ -45,6 +50,8 @@ export default function TestingForm({ setValues, }: TestingFormProps) { const { t } = useTranslation(); + const { id } = useParams(); // 正确解构出id参数 + const knowledgeBaseId = id; // 现在knowledgeBaseId是字符串类型 const formSchema = z.object({ question: z.string().min(1, { @@ -54,6 +61,8 @@ export default function TestingForm({ ...vectorSimilarityWeightSchema, ...topKSchema, use_kg: z.boolean().optional(), + kb_ids: z.array(z.string()).optional(), + ...MetadataFilterSchema, }); const form = useForm>({ @@ -63,6 +72,7 @@ export default function TestingForm({ ...initialVectorSimilarityWeightValue, ...initialTopKValue, use_kg: false, + kb_ids: [knowledgeBaseId], }, }); @@ -90,6 +100,8 @@ export default function TestingForm({ + {/* 添加元数据过滤组件 */} +