From 8d8a5f73b6d252f834f96cb160bc1ab36020fd1d Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Mon, 25 Aug 2025 18:29:24 +0800
Subject: [PATCH] Fix: metadata filter with AND logic operations. (#9687)

### What problem does this PR solve?

Close #9648

Multiple metadata filters were previously OR-ed together (every match was collected); `meta_filter` now intersects the document ids matched by each filter, giving the intended AND semantics (a usage sketch follows the diff below). The patch also makes hard timeouts opt-in via the `ENABLE_TIMEOUT_ASSERTION` environment variable and drops the `is_strong_enough` pressure test from the task executor.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 api/db/services/dialog_service.py          | 18 +++++++++++++-----
 api/utils/api_utils.py                     | 11 +++++++++--
 graphrag/entity_resolution.py              |  7 +++++--
 .../general/community_reports_extractor.py |  4 +++-
 graphrag/general/index.py                  |  5 ++++-
 graphrag/utils.py                          |  9 ++++++---
 rag/svr/task_executor.py                   |  7 +------
 7 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index fc369f2c1..e7e6f9038 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -256,10 +256,10 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
 
 
 def meta_filter(metas: dict, filters: list[dict]):
-    doc_ids = []
+    doc_ids = set([])
 
     def filter_out(v2docs, operator, value):
-        nonlocal doc_ids
+        ids = []
         for input, docids in v2docs.items():
             try:
                 input = float(input)
@@ -284,16 +284,24 @@ def meta_filter(metas: dict, filters: list[dict]):
             ]:
                 try:
                     if all(conds):
-                        doc_ids.extend(docids)
+                        ids.extend(docids)
+                        break
                 except Exception:
                     pass
+        return ids
 
     for k, v2docs in metas.items():
         for f in filters:
             if k != f["key"]:
                 continue
-            filter_out(v2docs, f["op"], f["value"])
-    return doc_ids
+            ids = filter_out(v2docs, f["op"], f["value"])
+            if not doc_ids:
+                doc_ids = set(ids)
+            else:
+                doc_ids = doc_ids & set(ids)
+            if not doc_ids:
+                return []
+    return list(doc_ids)
 
 
 def chat(dialog, messages, stream=True, **kwargs):
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index 4de7306b1..836a9da6c 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -17,6 +17,7 @@ import asyncio
 import functools
 import json
 import logging
+import os
 import queue
 import random
 import threading
@@ -667,7 +668,10 @@ def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Option
 
         for a in range(attempts):
             try:
-                result = result_queue.get(timeout=seconds)
+                if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
+                    result = result_queue.get(timeout=seconds)
+                else:
+                    result = result_queue.get()
                 if isinstance(result, Exception):
                     raise result
                 return result
@@ -682,7 +686,10 @@ def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Option
 
             for a in range(attempts):
                 try:
-                    with trio.fail_after(seconds):
+                    if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
+                        with trio.fail_after(seconds):
+                            return await func(*args, **kwargs)
+                    else:
                         return await func(*args, **kwargs)
                 except trio.TooSlowError:
                     if a < attempts - 1:
diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py
index 9cdb16bb9..01478d760 100644
--- a/graphrag/entity_resolution.py
+++ b/graphrag/entity_resolution.py
@@ -15,6 +15,7 @@
 #
 import logging
 import itertools
+import os
 import re
 from dataclasses import dataclass
 from typing import Any, Callable
@@ -106,7 +107,8 @@ class EntityResolution(Extractor):
             nonlocal remain_candidates_to_resolve, callback
             async with semaphore:
                 try:
-                    with trio.move_on_after(280) as cancel_scope:
+                    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
+                    with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
                         await self._resolve_candidate(candidate_batch, result_set, result_lock)
                         remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
                         callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
@@ -169,7 +171,8 @@ class EntityResolution(Extractor):
             logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
             async with chat_limiter:
                 try:
-                    with trio.move_on_after(280) as cancel_scope:
+                    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
+                    with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
                         response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
                     if cancel_scope.cancelled_caught:
                         logging.warning("_resolve_candidate._chat timeout, skipping...")
diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py
index 32b0d6a0f..6f9fd65b9 100644
--- a/graphrag/general/community_reports_extractor.py
+++ b/graphrag/general/community_reports_extractor.py
@@ -7,6 +7,7 @@ Reference:
 
 import logging
 import json
+import os
 import re
 from typing import Callable
 from dataclasses import dataclass
@@ -51,6 +52,7 @@ class CommunityReportsExtractor(Extractor):
         self._max_report_length = max_report_length or 1500
 
     async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
+        enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
         for node_degree in graph.degree:
             graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
 
@@ -92,7 +94,7 @@ class CommunityReportsExtractor(Extractor):
             text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
             async with chat_limiter:
                 try:
-                    with trio.move_on_after(180) as cancel_scope:
+                    with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
                         response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
                     if cancel_scope.cancelled_caught:
                         logging.warning("extract_community_report._chat timeout, skipping...")
diff --git a/graphrag/general/index.py b/graphrag/general/index.py
index ac7fb9607..e5150c54a 100644
--- a/graphrag/general/index.py
+++ b/graphrag/general/index.py
@@ -15,6 +15,8 @@
 #
 import json
 import logging
+import os
+
 import networkx as nx
 import trio
 
@@ -49,6 +51,7 @@ async def run_graphrag(
     embedding_model,
     callback,
 ):
+    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
@@ -57,7 +60,7 @@ async def run_graphrag(
     ):
         chunks.append(d["content_with_weight"])
 
-    with trio.fail_after(max(120, len(chunks)*60*10)):
+    with trio.fail_after(max(120, len(chunks)*60*10) if enable_timeout_assertion else 10000000000):
         subgraph = await generate_subgraph(
             LightKGExt
             if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general"
diff --git a/graphrag/utils.py b/graphrag/utils.py
index 4d2e79858..fbe391f8f 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -307,6 +307,7 @@ def chunk_id(chunk):
 
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     global chat_limiter
+    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     chunk = {
         "id": get_uuid(),
         "important_kwd": [ent_name],
@@ -324,7 +325,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
     if ebd is None:
         async with chat_limiter:
-            with trio.fail_after(3):
+            with trio.fail_after(3 if enable_timeout_assertion else 30000000):
                 ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
         ebd = ebd[0]
         set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
@@ -362,6 +363,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
 
 
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
+    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     chunk = {
         "id": get_uuid(),
         "from_entity_kwd": from_ent_name,
@@ -380,7 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
     ebd = get_embed_cache(embd_mdl.llm_name, txt)
     if ebd is None:
         async with chat_limiter:
-            with trio.fail_after(3):
+            with trio.fail_after(3 if enable_timeout_assertion else 300000000):
                 ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
         ebd = ebd[0]
         set_embed_cache(embd_mdl.llm_name, txt, ebd)
@@ -514,9 +516,10 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
         callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
     start = now
 
+    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        with trio.fail_after(3):
+        with trio.fail_after(3 if enable_timeout_assertion else 30000000):
             doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
         if b % 100 == es_bulk_size and callback:
             callback(msg=f"Insert chunks: {b}/{len(chunks)}")
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 27114154b..078d4f296 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -21,7 +21,7 @@ import sys
 import threading
 import time
 
-from api.utils.api_utils import timeout, is_strong_enough
+from api.utils.api_utils import timeout
 from api.utils.log_utils import init_root_logger, get_project_base_directory
 from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -478,8 +478,6 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
 
 @timeout(3600)
 async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
-    # Pressure test for GraphRAG task
-    await is_strong_enough(chat_mdl, embd_mdl)
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
@@ -553,7 +551,6 @@ async def do_handle_task(task):
     try:
         # bind embedding model
         embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
-        await is_strong_enough(None, embedding_model)
         vts, _ = embedding_model.encode(["ok"])
         vector_size = len(vts[0])
     except Exception as e:
@@ -568,7 +565,6 @@ async def do_handle_task(task):
     if task.get("task_type", "") == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await is_strong_enough(chat_model, None)
         # run RAPTOR
         async with kg_limiter:
             chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
@@ -580,7 +576,6 @@ async def do_handle_task(task):
         graphrag_conf = task["kb_parser_config"].get("graphrag", {})
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await is_strong_enough(chat_model, None)
         with_resolution = graphrag_conf.get("resolution", False)
         with_community = graphrag_conf.get("community", False)
         async with kg_limiter:
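
For reference, the core of the fix in `api/db/services/dialog_service.py`: each filter now yields its own list of matching document ids, and `meta_filter` intersects those sets, so every condition must hold. A minimal sketch of that set logic, with hypothetical doc ids and a hypothetical `and_filter` helper standing in for the matching loop above:

```python
# Minimal sketch of the AND semantics introduced by this patch; `and_filter`
# and the doc ids are hypothetical stand-ins for meta_filter's per-filter loop.
def and_filter(per_filter_matches: list[set[str]]) -> list[str]:
    doc_ids: set[str] = set()
    for ids in per_filter_matches:
        # Seed with the first filter's matches, then intersect with the rest.
        doc_ids = set(ids) if not doc_ids else doc_ids & set(ids)
        if not doc_ids:
            # AND short-circuits: once the intersection is empty, nothing can match.
            return []
    return list(doc_ids)

# Filter 1 (e.g. author == "alice") matches doc1/doc2;
# filter 2 (e.g. year >= 2024) matches doc2/doc3.
print(and_filter([{"doc1", "doc2"}, {"doc2", "doc3"}]))  # ['doc2']
```

Note that the same `if not doc_ids` test serves both as the first-filter seed and as the empty-intersection short-circuit, which works because the seed assignment always runs before any intersection in the same iteration.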
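The remaining hunks all share one pattern: every hard deadline (`result_queue.get(timeout=...)`, `trio.fail_after`, `trio.move_on_after`) now only fires when the `ENABLE_TIMEOUT_ASSERTION` environment variable is set; otherwise an effectively infinite deadline is substituted. A self-contained sketch of that toggle, with `slow_call` as a hypothetical stand-in for the LLM and embedding calls in the diff:

```python
# Sketch of the ENABLE_TIMEOUT_ASSERTION toggle used throughout this patch.
import os
import trio

async def slow_call():
    # Hypothetical stand-in for a chat/embedding request.
    await trio.sleep(2)
    return "ok"

async def main():
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    # When the variable is unset, the deadline is so large it never fires.
    with trio.move_on_after(1 if enable_timeout_assertion else 1000000000) as cancel_scope:
        result = await slow_call()
    if cancel_scope.cancelled_caught:
        print("timed out (assertion enabled)")
    else:
        print(result)

trio.run(main)
```

As in the patch, only the string's truthiness is checked, so any non-empty value (even "0" or "false") enables the timeouts.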