From 153e430b0035ecda3c86695bc48dfe1989525215 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Tue, 12 Aug 2025 14:12:56 +0800 Subject: [PATCH] Feat: add meta data filter. (#9405) ### What problem does this PR solve? #8531 #7417 #6761 #6573 #6477 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/dialog_app.py | 2 ++ api/apps/document_app.py | 5 +++ api/apps/kb_app.py | 15 ++++++++ api/db/db_models.py | 5 +++ api/db/services/dialog_service.py | 53 ++++++++++++++++++++++++++++- api/db/services/document_service.py | 19 +++++++++++ rag/nlp/search.py | 2 -- rag/prompts/meta_filter.md | 53 +++++++++++++++++++++++++++++ rag/prompts/prompts.py | 18 ++++++++++ rag/svr/task_executor.py | 2 +- rag/utils/s3_conn.py | 14 ++++++++ 11 files changed, 184 insertions(+), 4 deletions(-) create mode 100644 rag/prompts/meta_filter.md diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 4f55381a1..c5c48b6e1 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -51,6 +51,7 @@ def set_dialog(): similarity_threshold = req.get("similarity_threshold", 0.1) vector_similarity_weight = req.get("vector_similarity_weight", 0.3) llm_setting = req.get("llm_setting", {}) + meta_data_filter = req.get("meta_data_filter", {}) prompt_config = req["prompt_config"] if not is_create: @@ -85,6 +86,7 @@ def set_dialog(): "llm_id": llm_id, "llm_setting": llm_setting, "prompt_config": prompt_config, + "meta_data_filter": meta_data_filter, "top_n": top_n, "top_k": top_k, "rerank_id": rerank_id, diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 7f01f37a8..5d1531f9e 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -681,6 +681,11 @@ def set_meta(): return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) try: meta = json.loads(req["meta"]) + if not isinstance(meta, dict): + return get_json_result(data=False, message="Only dictionary type supported.", code=settings.RetCode.ARGUMENT_ERROR) + for k,v in meta.items(): + if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float): + return get_json_result(data=False, message=f"The type is not supported: {v}", code=settings.RetCode.ARGUMENT_ERROR) except Exception as e: return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR) if not isinstance(meta, dict): diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index fc00d0ac1..2e86a31bd 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -351,6 +351,7 @@ def knowledge_graph(kb_id): obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] return get_json_result(data=obj) + @manager.route('//knowledge_graph', methods=['DELETE']) # noqa: F821 @login_required def delete_knowledge_graph(kb_id): @@ -364,3 +365,17 @@ def delete_knowledge_graph(kb_id): settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) return get_json_result(data=True) + + +@manager.route("/get_meta", methods=["GET"]) # noqa: F821 +@login_required +def get_meta(): + kb_ids = request.args.get("kb_ids", "").split(",") + for kb_id in kb_ids: + if not KnowledgebaseService.accessible(kb_id, current_user.id): + return get_json_result( + data=False, + message='No authorization.', + code=settings.RetCode.AUTHENTICATION_ERROR + ) + return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids)) diff --git a/api/db/db_models.py b/api/db/db_models.py index 038fc8406..6438cf39c 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -744,6 +744,7 @@ class Dialog(DataBaseModel): null=False, default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?", "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"}, ) + meta_data_filter = JSONField(null=True, default={}) similarity_threshold = FloatField(default=0.2) vector_similarity_weight = FloatField(default=0.3) @@ -1015,4 +1016,8 @@ def migrate_db(): migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors"))) except Exception: pass + try: + migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={}))) + except Exception: + pass logging.disable(logging.NOTSET) \ No newline at end of file diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 5f9371e11..cbbc8ba9e 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -30,6 +30,7 @@ from api import settings from api.db import LLMType, ParserType, StatusEnum from api.db.db_models import DB, Dialog from api.db.services.common_service import CommonService +from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.llm_service import LLMBundle, TenantLLMService @@ -38,6 +39,7 @@ from rag.app.resume import forbidden_select_fields4resume from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in +from rag.prompts.prompts import gen_meta_filter from rag.utils import num_tokens_from_string, rmSpace from rag.utils.tavily_conn import Tavily @@ -250,6 +252,46 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): return answer, idx +def meta_filter(metas: dict, filters: list[dict]): + doc_ids = [] + def filter_out(v2docs, operator, value): + nonlocal doc_ids + for input,docids in v2docs.items(): + try: + input = float(input) + value = float(value) + except Exception: + input = str(input) + value = str(value) + + for conds in [ + (operator == "contains", str(value).lower() in str(input).lower()), + (operator == "not contains", str(value).lower() not in str(input).lower()), + (operator == "start with", str(input).lower().startswith(str(value).lower())), + (operator == "end with", str(input).lower().endswith(str(value).lower())), + (operator == "empty", not input), + (operator == "not empty", input), + (operator == "=", input == value), + (operator == "≠", input != value), + (operator == ">", input > value), + (operator == "<", input < value), + (operator == "≥", input >= value), + (operator == "≤", input <= value), + ]: + try: + if all(conds): + doc_ids.extend(docids) + except Exception: + pass + + for k, v2docs in metas.items(): + for f in filters: + if k != f["key"]: + continue + filter_out(v2docs, f["op"], f["value"]) + return doc_ids + + def chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): @@ -287,9 +329,10 @@ def chat(dialog, messages, stream=True, **kwargs): retriever = settings.retrievaler questions = [m["content"] for m in messages if m["role"] == "user"][-3:] - attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None + attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] if "doc_ids" in messages[-1]: attachments = messages[-1]["doc_ids"] + prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) # try to use sql if field mapping is good to go @@ -316,6 +359,14 @@ def chat(dialog, messages, stream=True, **kwargs): if prompt_config.get("cross_languages"): questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] + if dialog.meta_data_filter: + metas = DocumentService.get_meta_by_kbs(dialog.kb_ids) + if dialog.meta_data_filter.get("method") == "auto": + filters = gen_meta_filter(chat_mdl, metas, questions[-1]) + attachments.extend(meta_filter(metas, filters)) + elif dialog.meta_data_filter.get("method") == "manual": + attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"])) + if prompt_config.get("keyword", False): questions[-1] += keyword_extraction(chat_mdl, questions[-1]) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 303a357aa..6cd42735d 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -574,6 +574,25 @@ class DocumentService(CommonService): def update_meta_fields(cls, doc_id, meta_fields): return cls.update_by_id(doc_id, {"meta_fields": meta_fields}) + @classmethod + @DB.connection_context() + def get_meta_by_kbs(cls, kb_ids): + fields = [ + cls.model.id, + cls.model.meta_fields, + ] + meta = {} + for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)): + doc_id = r.id + for k,v in r.meta_fields.items(): + if k not in meta: + meta[k] = {} + v = str(v) + if v not in meta[k]: + meta[k][v] = [] + meta[k][v].append(doc_id) + return meta + @classmethod @DB.connection_context() def update_progress(cls): diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 660fab0bc..b1617b9a7 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -383,8 +383,6 @@ class Dealer: vector_column = f"q_{dim}_vec" zero_vector = [0.0] * dim sim_np = np.array(sim) - if doc_ids: - similarity_threshold = 0 filtered_count = (sim_np >= similarity_threshold).sum() ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error for i in idx: diff --git a/rag/prompts/meta_filter.md b/rag/prompts/meta_filter.md new file mode 100644 index 000000000..89e322fe5 --- /dev/null +++ b/rag/prompts/meta_filter.md @@ -0,0 +1,53 @@ +You are a metadata filtering condition generator. Analyze the user's question and available document metadata to output a JSON array of filter objects. Follow these rules: + +1. **Metadata Structure**: + - Metadata is provided as JSON where keys are attribute names (e.g., "color"), and values are objects mapping attribute values to document IDs. + - Example: + { + "color": {"red": ["doc1"], "blue": ["doc2"]}, + "listing_date": {"2025-07-11": ["doc1"], "2025-08-01": ["doc2"]} + } + +2. **Output Requirements**: + - Always output a JSON array of filter objects + - Each object must have: + "key": (metadata attribute name), + "value": (string value to compare), + "op": (operator from allowed list) + +3. **Operator Guide**: + - Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"] + - Date ranges: Break into two conditions (≥ start_date AND < next_month_start) + - Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠") + - Implicit logic: Derive unstated filters (e.g., "July" → [≥ YYYY-07-01, < YYYY-08-01]) + +4. **Processing Steps**: + a) Identify ALL filterable attributes in the query (both explicit and implicit) + b) For dates: + - Infer missing year from current date if needed + - Always format dates as "YYYY-MM-DD" + - Convert ranges: [≥ start, < end] + c) For values: Match EXACTLY to metadata's value keys + d) Skip conditions if: + - Attribute doesn't exist in metadata + - Value has no match in metadata + +5. **Example**: + - User query: "上市日期七月份的有哪些商品,不要蓝色的" + - Metadata: { "color": {...}, "listing_date": {...} } + - Output: + [ + {"key": "listing_date", "value": "2025-07-01", "op": "≥"}, + {"key": "listing_date", "value": "2025-08-01", "op": "<"}, + {"key": "color", "value": "blue", "op": "≠"} + ] + +6. **Final Output**: + - ONLY output valid JSON array + - NO additional text/explanations + +**Current Task**: +- Today's date: {{current_date}} +- Available metadata keys: {{metadata_keys}} +- User query: "{{user_question}}" + diff --git a/rag/prompts/prompts.py b/rag/prompts/prompts.py index 75c9369b8..e3ce3e457 100644 --- a/rag/prompts/prompts.py +++ b/rag/prompts/prompts.py @@ -149,6 +149,7 @@ NEXT_STEP = load_prompt("next_step") REFLECT = load_prompt("reflect") SUMMARY4MEMORY = load_prompt("summary4memory") RANK_MEMORY = load_prompt("rank_memory") +META_FILTER = load_prompt("meta_filter") PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True) @@ -413,3 +414,20 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>") return re.sub(r"^.*", "", ans, flags=re.DOTALL) + +def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list: + sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render( + current_date=datetime.datetime.today().strftime('%Y-%m-%d'), + metadata_keys=json.dumps(meta_data), + user_question=query + ) + user_prompt = "Generate filters:" + ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}]) + ans = re.sub(r"(^.*|```json\n|```\n*$)", "", ans, flags=re.DOTALL) + try: + ans = json_repair.loads(ans) + assert isinstance(ans, list), ans + return ans + except Exception: + logging.exception(f"Loading json failure: {ans}") + return [] \ No newline at end of file diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index f089b0f4e..649c7e95e 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -444,7 +444,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): tts = np.concatenate([vts for _ in range(len(tts))], axis=0) tk_count += c - @timeout(5) + @timeout(60) def batch_encode(txts): nonlocal mdl return mdl.encode([truncate(c, mdl.max_length-10) for c in txts]) diff --git a/rag/utils/s3_conn.py b/rag/utils/s3_conn.py index 9b18d59bd..038e47135 100644 --- a/rag/utils/s3_conn.py +++ b/rag/utils/s3_conn.py @@ -190,3 +190,17 @@ class RAGFlowS3: self.__open__() time.sleep(1) return + + @use_prefix_path + @use_default_bucket + def rm_bucket(self, bucket, *args, **kwargs): + for conn in self.conn: + try: + if not conn.bucket_exists(bucket): + continue + for o in conn.list_objects_v2(Bucket=bucket): + conn.delete_object(bucket, o.object_name) + conn.delete_bucket(Bucket=bucket) + return + except Exception as e: + logging.error(f"Fail rm {bucket}: " + str(e))