diff --git a/common/metadata_utils.py b/common/metadata_utils.py index cbe6dfe7c..aab00df8a 100644 --- a/common/metadata_utils.py +++ b/common/metadata_utils.py @@ -145,11 +145,22 @@ async def apply_meta_data_filter( if not doc_ids: return None elif method == "semi_auto": - selected_keys = meta_data_filter.get("semi_auto", []) + selected_keys = [] + constraints = {} + for item in meta_data_filter.get("semi_auto", []): + if isinstance(item, str): + selected_keys.append(item) + elif isinstance(item, dict): + key = item.get("key") + op = item.get("op") + selected_keys.append(key) + if op: + constraints[key] = op + if selected_keys: filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question, constraints=constraints) doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: return None diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py index 494e1915b..369d0448d 100644 --- a/rag/nlp/rag_tokenizer.py +++ b/rag/nlp/rag_tokenizer.py @@ -15,18 +15,17 @@ # import infinity.rag_tokenizer -from common import settings - - class RagTokenizer(infinity.rag_tokenizer.RagTokenizer): def tokenize(self, line: str) -> str: + from common import settings # moved from the top of the file to avoid circular import if settings.DOC_ENGINE_INFINITY: return line else: return super().tokenize(line) def fine_grained_tokenize(self, tks: str) -> str: + from common import settings # moved from the top of the file to avoid circular import if settings.DOC_ENGINE_INFINITY: return tks else: diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 08c1c5c08..d6cd6de51 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -20,7 +20,6 @@ import math from collections import OrderedDict, defaultdict from dataclasses import dataclass -from rag.prompts.generator import relevant_chunks_with_toc from rag.nlp import rag_tokenizer, query import numpy as np from common.doc_store.doc_store_base import MatchDenseExpr, FusionExpr, OrderByExpr, DocStoreConnection @@ -591,6 +590,7 @@ class Dealer: return {a.replace(".", "_"): max(1, c) for a, c in tag_fea} async def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6): + from rag.prompts.generator import relevant_chunks_with_toc # moved from the top of the file to avoid circular import if not chunks: return [] idx_nms = [index_name(tid) for tid in tenant_ids] diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index d3d7b65a6..609f2a6bc 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -467,7 +467,7 @@ async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summ return re.sub(r"^.*", "", ans, flags=re.DOTALL) -async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict: +async def gen_meta_filter(chat_mdl, meta_data: dict, query: str, constraints: dict = None) -> dict: meta_data_structure = {} for key, values in meta_data.items(): meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values @@ -475,7 +475,8 @@ async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict: sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render( current_date=datetime.datetime.today().strftime('%Y-%m-%d'), metadata_keys=json.dumps(meta_data_structure), - user_question=query + user_question=query, + constraints=json.dumps(constraints) if constraints else None ) user_prompt = "Generate filters:" ans = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}]) diff --git a/rag/prompts/meta_filter.md b/rag/prompts/meta_filter.md index 203291071..28aff93a2 100644 --- a/rag/prompts/meta_filter.md +++ b/rag/prompts/meta_filter.md @@ -18,12 +18,17 @@ You are a metadata filtering condition generator. Analyze the user's question an 3. **Operator Guide**: - - Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"] + - Use these operators only: ["contains", "not contains","in", "not in", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"] - Date ranges: Break into two conditions (≥ start_date AND < next_month_start) - Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠") - Implicit logic: Derive unstated filters (e.g., "July" → [≥ YYYY-07-01, < YYYY-08-01]) -4. **Processing Steps**: +4. **Operator Constraints**: + - If `constraints` are provided, you MUST use the specified operator for the corresponding key. + - Example Constraints: `{"price": ">", "author": "="}` + - If a key is not in `constraints`, choose the most appropriate operator. + +5. **Processing Steps**: a) Identify ALL filterable attributes in the query (both explicit and implicit) b) For dates: - Infer missing year from current date if needed @@ -34,7 +39,7 @@ You are a metadata filtering condition generator. Analyze the user's question an - Attribute doesn't exist in metadata - Value has no match in metadata -5. **Example A**: +6. **Example A**: - User query: "上市日期七月份的有哪些新品,不要蓝色的,只看鞋子和帽子" - Metadata: { "color": {...}, "listing_date": {...} } - Output: @@ -48,7 +53,7 @@ You are a metadata filtering condition generator. Analyze the user's question an ] } -6. **Example B**: +7. **Example B**: - User query: "It must be from China or India. Otherwise, it must not be blue or red." - Metadata: { "color": {...}, "country": {...} } - @@ -61,7 +66,7 @@ You are a metadata filtering condition generator. Analyze the user's question an ] } -7. **Final Output**: +8. **Final Output**: - ONLY output valid JSON dictionary - NO additional text/explanations - Json schema is as following: @@ -131,4 +136,7 @@ You are a metadata filtering condition generator. Analyze the user's question an - Today's date: {{ current_date }} - Available metadata keys: {{ metadata_keys }} - User query: "{{ user_question }}" +{% if constraints %} +- Operator constraints: {{ constraints }} +{% endif %} diff --git a/test/unit_test/common/test_apply_semi_auto_meta_data_filter.py b/test/unit_test/common/test_apply_semi_auto_meta_data_filter.py new file mode 100644 index 000000000..165e283aa --- /dev/null +++ b/test/unit_test/common/test_apply_semi_auto_meta_data_filter.py @@ -0,0 +1,53 @@ +import pytest +from common.metadata_utils import apply_meta_data_filter +from unittest.mock import MagicMock, AsyncMock, patch + +@pytest.mark.asyncio +async def test_apply_meta_data_filter_semi_auto_key(): + meta_data_filter = { + "method": "semi_auto", + "semi_auto": ["key1", "key2"] + } + metas = { + "key1": {"val1": ["doc1"]}, + "key2": {"val2": ["doc2"]} + } + question = "find val1" + + chat_mdl = MagicMock() + + with patch("rag.prompts.generator.gen_meta_filter", new_callable=AsyncMock) as mock_gen: + mock_gen.return_value = {"conditions": [{"key": "key1", "op": "=", "value": "val1"}], "logic": "and"} + + doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl) + assert doc_ids == ["doc1"] + + # Check that constraints is an empty dict by default for legacy + mock_gen.assert_called_once() + args, kwargs = mock_gen.call_args + assert kwargs["constraints"] == {} + +@pytest.mark.asyncio +async def test_apply_meta_data_filter_semi_auto_key_and_operator(): + meta_data_filter = { + "method": "semi_auto", + "semi_auto": [{"key": "key1", "op": ">"}, "key2"] + } + metas = { + "key1": {"10": ["doc1"]}, + "key2": {"val2": ["doc2"]} + } + question = "find key1 > 5" + + chat_mdl = MagicMock() + + with patch("rag.prompts.generator.gen_meta_filter", new_callable=AsyncMock) as mock_gen: + mock_gen.return_value = {"conditions": [{"key": "key1", "op": ">", "value": "5"}], "logic": "and"} + + doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl) + assert doc_ids == ["doc1"] + + # Check that constraints are correctly passed + mock_gen.assert_called_once() + args, kwargs = mock_gen.call_args + assert kwargs["constraints"] == {"key1": ">"} diff --git a/web/src/components/metadata-filter/index.tsx b/web/src/components/metadata-filter/index.tsx index 10c7b1e6b..988ad8b8c 100644 --- a/web/src/components/metadata-filter/index.tsx +++ b/web/src/components/metadata-filter/index.tsx @@ -26,7 +26,17 @@ export const MetadataFilterSchema = { }), ) .optional(), - semi_auto: z.array(z.string()).optional(), + semi_auto: z + .array( + z.union([ + z.string(), + z.object({ + key: z.string(), + op: z.string().optional(), + }), + ]), + ) + .optional(), }) .optional(), }; diff --git a/web/src/components/metadata-filter/metadata-semi-auto-fields.tsx b/web/src/components/metadata-filter/metadata-semi-auto-fields.tsx index 9bab0ebbb..57fb686c7 100644 --- a/web/src/components/metadata-filter/metadata-semi-auto-fields.tsx +++ b/web/src/components/metadata-filter/metadata-semi-auto-fields.tsx @@ -1,10 +1,4 @@ import { Button } from '@/components/ui/button'; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuTrigger, -} from '@/components/ui/dropdown-menu'; import { FormControl, FormField, @@ -12,12 +6,13 @@ import { FormLabel, FormMessage, } from '@/components/ui/form'; -import { Input } from '@/components/ui/input'; +import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options'; import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request'; import { Plus, X } from 'lucide-react'; -import { useCallback } from 'react'; +import { useCallback, useMemo } from 'react'; import { useFieldArray, useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; +import { SelectWithSearch } from '../originui/select-with-search'; export function MetadataSemiAutoFields({ kbIds, @@ -36,59 +31,86 @@ export function MetadataSemiAutoFields({ control: form.control, }); - const add = useCallback( - (key: string) => () => { - append(key); - }, - [append], - ); + const add = useCallback(() => { + append({ key: '', op: '' }); + }, [append]); + + const switchOperatorOptions = useBuildSwitchOperatorOptions(); + + const autoOption = { label: t('chat.meta.auto'), value: '' }; + + const metadataOptions = useMemo(() => { + return Object.keys(metadata.data || {}).map((key) => ({ + label: key, + value: key, + })); + }, [metadata.data]); return (
{t('chat.metadataKeys')} - - - - - - {Object.keys(metadata.data).map((key, idx) => { - return ( - - {key} - - ); - })} - - +
-
+
{fields.map((field, index) => { - const typeField = `${name}.${index}`; + const keyField = `${name}.${index}.key`; + const opField = `${name}.${index}.op`; return ( -
-
- ( - - - - - - - )} - /> -
-
); diff --git a/web/src/components/originui/select-with-search.tsx b/web/src/components/originui/select-with-search.tsx index 5990bf540..5d81471cb 100644 --- a/web/src/components/originui/select-with-search.tsx +++ b/web/src/components/originui/select-with-search.tsx @@ -159,9 +159,9 @@ export const SelectWithSearch = forwardRef< triggerClassName, )} > - {value ? ( + {selectLabel || value ? ( - {selectLabel} + {selectLabel || value} ) : ( {placeholder}