Fix: metadata filter with AND logic operations. (#9687)

### What problem does this PR solve?

Close #9648
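
In short: `meta_filter` previously OR-ed the matches from every metadata filter; it now intersects them so a document must satisfy all filters. The PR also makes the pipeline's timeout assertions opt-in via the `ENABLE_TIMEOUT_ASSERTION` environment variable and drops the `is_strong_enough` pressure-test calls from the task executor.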

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
Kevin Hu committed on 2025-08-25 18:29:24 +08:00 (committed by GitHub)
parent d0fa66f4d5, commit 8d8a5f73b6
7 changed files with 41 additions and 20 deletions

View File

```diff
@@ -256,10 +256,10 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
 def meta_filter(metas: dict, filters: list[dict]):
-    doc_ids = []
+    doc_ids = set([])
 
     def filter_out(v2docs, operator, value):
-        nonlocal doc_ids
+        ids = []
         for input, docids in v2docs.items():
             try:
                 input = float(input)
@@ -284,16 +284,24 @@ def meta_filter(metas: dict, filters: list[dict]):
             ]:
                 try:
                     if all(conds):
-                        doc_ids.extend(docids)
+                        ids.extend(docids)
+                        break
                 except Exception:
                     pass
+        return ids
 
     for k, v2docs in metas.items():
         for f in filters:
             if k != f["key"]:
                 continue
-            filter_out(v2docs, f["op"], f["value"])
-    return doc_ids
+            ids = filter_out(v2docs, f["op"], f["value"])
+            if not doc_ids:
+                doc_ids = set(ids)
+            else:
+                doc_ids = doc_ids & set(ids)
+            if not doc_ids:
+                return []
+    return list(doc_ids)
 
 def chat(dialog, messages, stream=True, **kwargs):
```
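
For illustration, a minimal standalone sketch of the new AND semantics, using hypothetical metadata and an equality-only operator (the real `filter_out` supports more operators):

```python
# Hypothetical metadata, mirroring the metas layout above: key -> value -> doc IDs.
metas = {
    "author": {"alice": ["doc1", "doc2"], "bob": ["doc3"]},
    "year": {"2024": ["doc2", "doc3"], "2023": ["doc1"]},
}
filters = [
    {"key": "author", "op": "is", "value": "alice"},
    {"key": "year", "op": "is", "value": "2024"},
]

doc_ids = set()
for f in filters:
    # Matches for this one filter ("is" operator only, for brevity).
    ids = {d for v, docs in metas[f["key"]].items() if str(v) == str(f["value"]) for d in docs}
    doc_ids = ids if not doc_ids else doc_ids & ids  # intersect: AND across filters
    if not doc_ids:  # no document can satisfy every filter; stop early
        break
print(sorted(doc_ids))  # ['doc2'] — before the fix, extend() would have returned doc1-doc3
```

The `set(ids)` in the committed code also dedupes document IDs that several metadata values map to, which the old list-based `extend` did not.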

View File

```diff
@@ -17,6 +17,7 @@ import asyncio
 import functools
 import json
 import logging
+import os
 import queue
 import random
 import threading
@@ -667,7 +668,10 @@ def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Option
         for a in range(attempts):
             try:
-                result = result_queue.get(timeout=seconds)
+                if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
+                    result = result_queue.get(timeout=seconds)
+                else:
+                    result = result_queue.get()
                 if isinstance(result, Exception):
                     raise result
                 return result
@@ -682,7 +686,10 @@ def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Option
         for a in range(attempts):
             try:
-                with trio.fail_after(seconds):
+                if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
+                    with trio.fail_after(seconds):
+                        return await func(*args, **kwargs)
+                else:
                     return await func(*args, **kwargs)
             except trio.TooSlowError:
                 if a < attempts - 1:
```
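
In isolation, the gating pattern looks like this. A standalone sketch, not the project's decorator; `wait_for_result` is a made-up name:

```python
import os
import queue
import threading

# Sketch of the gate above: enforce the timeout only when the
# ENABLE_TIMEOUT_ASSERTION environment variable is set to a non-empty string.
def wait_for_result(q: queue.Queue, seconds: float):
    if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
        return q.get(timeout=seconds)  # raises queue.Empty on timeout
    return q.get()  # no deadline: block until a worker delivers a result

q = queue.Queue()
threading.Timer(0.1, lambda: q.put("done")).start()
os.environ["ENABLE_TIMEOUT_ASSERTION"] = "1"  # opt in to strict timeouts
print(wait_for_result(q, seconds=5.0))  # -> done
```

Since `os.environ.get` returns `None` when the variable is unset, the timeout assertions are now disabled by default and only enforced in environments that export the flag.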

View File

```diff
@@ -15,6 +15,7 @@
 #
 import logging
 import itertools
+import os
 import re
 from dataclasses import dataclass
 from typing import Any, Callable
@@ -106,7 +107,8 @@ class EntityResolution(Extractor):
             nonlocal remain_candidates_to_resolve, callback
             async with semaphore:
                 try:
-                    with trio.move_on_after(280) as cancel_scope:
+                    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
+                    with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
                         await self._resolve_candidate(candidate_batch, result_set, result_lock)
                     remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
                     callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
@@ -169,7 +171,8 @@ class EntityResolution(Extractor):
         logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
         async with chat_limiter:
             try:
-                with trio.move_on_after(280) as cancel_scope:
+                enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
+                with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
                     response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
                 if cancel_scope.cancelled_caught:
                     logging.warning("_resolve_candidate._chat timeout, skipping...")
```
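
`trio.move_on_after` cancels the body at the deadline without raising, and the 1000000000-second fallback is roughly 31 years, i.e. effectively no deadline. A standalone sketch of the pattern:

```python
import os
import trio

# Sketch of the pattern above: a short deadline when the assertion flag is
# set, an effectively infinite one when it is not.
async def main():
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    with trio.move_on_after(0.1 if enable_timeout_assertion else 1000000000) as cancel_scope:
        await trio.sleep(1)  # stands in for the LLM / resolution call
    if cancel_scope.cancelled_caught:
        print("timeout, skipping...")  # reached only when the flag is set
    else:
        print("completed")

trio.run(main)
```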

View File

```diff
@@ -7,6 +7,7 @@ Reference:
 import logging
 import json
+import os
 import re
 from typing import Callable
 from dataclasses import dataclass
@@ -51,6 +52,7 @@ class CommunityReportsExtractor(Extractor):
         self._max_report_length = max_report_length or 1500
 
     async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
+        enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
         for node_degree in graph.degree:
             graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
@@ -92,7 +94,7 @@ class CommunityReportsExtractor(Extractor):
             text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
             async with chat_limiter:
                 try:
-                    with trio.move_on_after(180) as cancel_scope:
+                    with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
                         response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {})
                     if cancel_scope.cancelled_caught:
                         logging.warning("extract_community_report._chat timeout, skipping...")
```

View File

```diff
@@ -15,6 +15,8 @@
 #
 import json
 import logging
+import os
+
 import networkx as nx
 import trio
@@ -49,6 +51,7 @@ async def run_graphrag(
     embedding_model,
     callback,
 ):
+    enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
@@ -57,7 +60,7 @@ async def run_graphrag(
     ):
         chunks.append(d["content_with_weight"])
 
-    with trio.fail_after(max(120, len(chunks)*60*10)):
+    with trio.fail_after(max(120, len(chunks)*60*10) if enable_timeout_assertion else 10000000000):
         subgraph = await generate_subgraph(
             LightKGExt
             if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general"
```
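
Unlike `move_on_after`, `trio.fail_after` raises `trio.TooSlowError` at the deadline instead of silently cancelling. A standalone sketch of that difference:

```python
import trio

# fail_after raises on deadline, so callers must catch TooSlowError;
# move_on_after (used in the extractors above) just cancels and continues.
async def main():
    try:
        with trio.fail_after(0.1):
            await trio.sleep(1)  # stands in for generate_subgraph(...)
    except trio.TooSlowError:
        print("deadline exceeded")

trio.run(main)
```

Because `fail_after` raises rather than skips, this call site scales its deadline with the input size (`max(120, len(chunks)*60*10)`) when the assertion is enabled.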

View File

```diff
@@ -307,6 +307,7 @@ def chunk_id(chunk):
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     global chat_limiter
+    enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     chunk = {
         "id": get_uuid(),
         "important_kwd": [ent_name],
@@ -324,7 +325,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
     if ebd is None:
         async with chat_limiter:
-            with trio.fail_after(3):
+            with trio.fail_after(3 if enable_timeout_assertion else 30000000):
                 ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
         ebd = ebd[0]
         set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
@@ -362,6 +363,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
+    enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     chunk = {
         "id": get_uuid(),
         "from_entity_kwd": from_ent_name,
@@ -380,7 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
     ebd = get_embed_cache(embd_mdl.llm_name, txt)
     if ebd is None:
         async with chat_limiter:
-            with trio.fail_after(3):
+            with trio.fail_after(3 if enable_timeout_assertion else 300000000):
                 ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
         ebd = ebd[0]
         set_embed_cache(embd_mdl.llm_name, txt, ebd)
@@ -514,9 +516,10 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
         callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
         start = now
 
+    enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        with trio.fail_after(3):
+        with trio.fail_after(3 if enable_timeout_assertion else 30000000):
            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
         if b % 100 == es_bulk_size and callback:
             callback(msg=f"Insert chunks: {b}/{len(chunks)}")
```

View File

```diff
@@ -21,7 +21,7 @@ import sys
 import threading
 import time
 
-from api.utils.api_utils import timeout, is_strong_enough
+from api.utils.api_utils import timeout
 from api.utils.log_utils import init_root_logger, get_project_base_directory
 from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -478,8 +478,6 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
 @timeout(3600)
 async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
-    # Pressure test for GraphRAG task
-    await is_strong_enough(chat_mdl, embd_mdl)
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
@@ -553,7 +551,6 @@ async def do_handle_task(task):
     try:
         # bind embedding model
         embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
-        await is_strong_enough(None, embedding_model)
         vts, _ = embedding_model.encode(["ok"])
         vector_size = len(vts[0])
     except Exception as e:
@@ -568,7 +565,6 @@ async def do_handle_task(task):
     if task.get("task_type", "") == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await is_strong_enough(chat_model, None)
         # run RAPTOR
         async with kg_limiter:
             chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
@@ -580,7 +576,6 @@ async def do_handle_task(task):
         graphrag_conf = task["kb_parser_config"].get("graphrag", {})
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await is_strong_enough(chat_model, None)
         with_resolution = graphrag_conf.get("resolution", False)
         with_community = graphrag_conf.get("community", False)
         async with kg_limiter:
```