Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Fix: meta data filter with AND logic operations. (#9687)

### What problem does this PR solve?

Close #9648

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
```diff
@@ -256,10 +256,10 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
 
 
 def meta_filter(metas: dict, filters: list[dict]):
-    doc_ids = []
+    doc_ids = set([])
 
     def filter_out(v2docs, operator, value):
-        nonlocal doc_ids
+        ids = []
         for input, docids in v2docs.items():
             try:
                 input = float(input)
@@ -284,16 +284,24 @@ def meta_filter(metas: dict, filters: list[dict]):
             ]:
                 try:
                     if all(conds):
-                        doc_ids.extend(docids)
+                        ids.extend(docids)
+                        break
                 except Exception:
                     pass
+        return ids
 
     for k, v2docs in metas.items():
         for f in filters:
             if k != f["key"]:
                 continue
-            filter_out(v2docs, f["op"], f["value"])
-    return doc_ids
+            ids = filter_out(v2docs, f["op"], f["value"])
+            if not doc_ids:
+                doc_ids = set(ids)
+            else:
+                doc_ids = doc_ids & set(ids)
+    if not doc_ids:
+        return []
+    return list(doc_ids)
 
 
 def chat(dialog, messages, stream=True, **kwargs):
```
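Taken together, the two hunks change `meta_filter` from OR to AND semantics: each filter's matches are now collected per filter and intersected, instead of being appended to one shared list. A runnable sketch of the new logic, with `filter_out` simplified to a single illustrative `is` (string-equality) operator rather than the full condition table in the real function:

```python
# Standalone sketch of the new AND semantics. The metas layout follows the
# diff: {meta_key: {meta_value: [doc_ids...]}}. The "is" operator here is
# illustrative; the real filter_out also coerces values to float and
# evaluates a list of operator conditions.
def meta_filter(metas: dict, filters: list[dict]):
    doc_ids = set([])

    def filter_out(v2docs, operator, value):
        ids = []
        for input, docids in v2docs.items():
            if operator == "is" and str(input) == str(value):
                ids.extend(docids)
        return ids

    for k, v2docs in metas.items():
        for f in filters:
            if k != f["key"]:
                continue
            # AND logic: intersect each filter's matches instead of
            # extending one shared list (the old OR behavior).
            ids = filter_out(v2docs, f["op"], f["value"])
            doc_ids = set(ids) if not doc_ids else doc_ids & set(ids)
    if not doc_ids:
        return []
    return list(doc_ids)

metas = {
    "author": {"alice": ["doc1", "doc2"], "bob": ["doc3"]},
    "year": {"2024": ["doc2", "doc3"]},
}
filters = [
    {"key": "author", "op": "is", "value": "alice"},
    {"key": "year", "op": "is", "value": "2024"},
]
print(meta_filter(metas, filters))  # ['doc2']; the old code returned all three docs
```

The added `break` in the real function also stops scanning further conditions once one matches, so a value's documents are not extended twice for the same filter.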
```diff
@@ -17,6 +17,7 @@ import asyncio
 import functools
 import json
 import logging
+import os
 import queue
 import random
 import threading
@@ -667,7 +668,10 @@ def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Option
 
         for a in range(attempts):
             try:
-                result = result_queue.get(timeout=seconds)
+                if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
+                    result = result_queue.get(timeout=seconds)
+                else:
+                    result = result_queue.get()
                 if isinstance(result, Exception):
                     raise result
                 return result
@@ -682,7 +686,10 @@ def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Option
 
         for a in range(attempts):
             try:
-                with trio.fail_after(seconds):
+                if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
+                    with trio.fail_after(seconds):
+                        return await func(*args, **kwargs)
+                else:
                     return await func(*args, **kwargs)
             except trio.TooSlowError:
                 if a < attempts - 1:
```
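These two hunks make the `timeout` decorator's deadlines opt-in: both the thread-backed branch and the trio branch only enforce `seconds` when `ENABLE_TIMEOUT_ASSERTION` is set in the environment. A minimal sketch of the same gate on a plain `queue.Queue` read (`gated_get` and the worker thread are illustrative, not from the codebase):

```python
import os
import queue
import threading
import time

def gated_get(q: queue.Queue, seconds: float):
    # Same gate as the diff: enforce the timeout only when
    # ENABLE_TIMEOUT_ASSERTION is set (q.get(timeout=...) raises
    # queue.Empty on expiry); otherwise block until a result arrives.
    if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
        return q.get(timeout=seconds)
    return q.get()

q = queue.Queue()
threading.Thread(target=lambda: (time.sleep(2), q.put("done")), daemon=True).start()

os.environ.pop("ENABLE_TIMEOUT_ASSERTION", None)
print(gated_get(q, seconds=1))  # flag unset: ignores the 1s budget and prints "done"
```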
```diff
@@ -15,6 +15,7 @@
 #
 import logging
 import itertools
+import os
 import re
 from dataclasses import dataclass
 from typing import Any, Callable
@@ -106,7 +107,8 @@ class EntityResolution(Extractor):
             nonlocal remain_candidates_to_resolve, callback
             async with semaphore:
                 try:
-                    with trio.move_on_after(280) as cancel_scope:
+                    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
+                    with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
                         await self._resolve_candidate(candidate_batch, result_set, result_lock)
                     remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
                     callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
@@ -169,7 +171,8 @@ class EntityResolution(Extractor):
         logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
         async with chat_limiter:
             try:
-                with trio.move_on_after(280) as cancel_scope:
+                enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
+                with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
                     response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
                 if cancel_scope.cancelled_caught:
                     logging.warning("_resolve_candidate._chat timeout, skipping...")
```
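In `EntityResolution`, the `trio.move_on_after` scopes get the same gate, with 1000000000 seconds (roughly 31.7 years) standing in for "no deadline". `move_on_after` cancels its body silently and records the fact in `cancelled_caught`, which is what drives the "timeout, skipping..." warnings. A runnable sketch of the pattern (the 0.1 s deadline and the `trio.sleep` stand-in are illustrative):

```python
import os
import trio

async def resolve_with_deadline():
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    # move_on_after cancels its body silently when the deadline passes
    # (unlike fail_after, which raises); cancelled_caught records it.
    with trio.move_on_after(0.1 if enable_timeout_assertion else 1000000000) as cancel_scope:
        await trio.sleep(0.5)  # stands in for the LLM call
    if cancel_scope.cancelled_caught:
        print("timed out, skipping...")
    else:
        print("completed")

os.environ["ENABLE_TIMEOUT_ASSERTION"] = "1"
trio.run(resolve_with_deadline)  # prints "timed out, skipping..." after ~0.1s
```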
```diff
@@ -7,6 +7,7 @@ Reference:
 
 import logging
 import json
+import os
 import re
 from typing import Callable
 from dataclasses import dataclass
@@ -51,6 +52,7 @@ class CommunityReportsExtractor(Extractor):
         self._max_report_length = max_report_length or 1500
 
     async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
+        enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
         for node_degree in graph.degree:
             graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
 
@@ -92,7 +94,7 @@ class CommunityReportsExtractor(Extractor):
             text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
             async with chat_limiter:
                 try:
-                    with trio.move_on_after(180) as cancel_scope:
+                    with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
                         response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {})
                     if cancel_scope.cancelled_caught:
                         logging.warning("extract_community_report._chat timeout, skipping...")
```
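In `CommunityReportsExtractor` the flag is read once at the top of `__call__` and reused by the `move_on_after` hunk below, rather than re-read at each call site. The check is plain string truthiness on `os.environ.get`, which returns `None` when the variable is unset, so any non-empty value enables the deadlines (even `"0"`):

```python
import os

# Any non-empty value enables the assertion; unset (or "") disables it.
os.environ["ENABLE_TIMEOUT_ASSERTION"] = "1"
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
print(180 if enable_timeout_assertion else 1000000000)  # 180

del os.environ["ENABLE_TIMEOUT_ASSERTION"]
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
print(180 if enable_timeout_assertion else 1000000000)  # 1000000000
```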
```diff
@@ -15,6 +15,8 @@
 #
 import json
 import logging
+import os
+
 import networkx as nx
 import trio
 
@@ -49,6 +51,7 @@ async def run_graphrag(
     embedding_model,
     callback,
 ):
+    enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
@@ -57,7 +60,7 @@ async def run_graphrag(
     ):
         chunks.append(d["content_with_weight"])
 
-    with trio.fail_after(max(120, len(chunks)*60*10)):
+    with trio.fail_after(max(120, len(chunks)*60*10) if enable_timeout_assertion else 10000000000):
         subgraph = await generate_subgraph(
             LightKGExt
             if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general"
```
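`run_graphrag` uses `trio.fail_after` rather than `move_on_after`, so an expired budget raises `trio.TooSlowError` instead of cancelling silently. The budget also scales with the document: at least 120 s, otherwise 600 s (60*10) per chunk. A small sketch of the budget and its failure mode (the helper name and the 0.1 s demo deadline are illustrative):

```python
import trio

def graphrag_budget(num_chunks: int, enable_timeout_assertion: bool) -> int:
    # Mirrors the diff: at least 120s, else 600s per chunk; the
    # 10000000000s fallback effectively disables the deadline.
    return max(120, num_chunks * 60 * 10) if enable_timeout_assertion else 10000000000

print(graphrag_budget(0, True))   # 120  (floor for tiny documents)
print(graphrag_budget(5, True))   # 3000 (600s per chunk)
print(graphrag_budget(5, False))  # 10000000000

async def main():
    try:
        with trio.fail_after(0.1):   # stands in for the computed budget
            await trio.sleep(1)      # stands in for generate_subgraph(...)
    except trio.TooSlowError:
        print("subgraph build exceeded its budget")

trio.run(main)
```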
```diff
@@ -307,6 +307,7 @@ def chunk_id(chunk):
 
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     global chat_limiter
+    enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     chunk = {
         "id": get_uuid(),
         "important_kwd": [ent_name],
@@ -324,7 +325,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
     if ebd is None:
         async with chat_limiter:
-            with trio.fail_after(3):
+            with trio.fail_after(3 if enable_timeout_assertion else 30000000):
                 ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
         ebd = ebd[0]
         set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
@@ -362,6 +363,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
 
 
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
+    enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     chunk = {
         "id": get_uuid(),
         "from_entity_kwd": from_ent_name,
@@ -380,7 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
     ebd = get_embed_cache(embd_mdl.llm_name, txt)
     if ebd is None:
         async with chat_limiter:
-            with trio.fail_after(3):
+            with trio.fail_after(3 if enable_timeout_assertion else 300000000):
                 ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
         ebd = ebd[0]
         set_embed_cache(embd_mdl.llm_name, txt, ebd)
@@ -514,9 +516,10 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
         callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
         start = now
 
+    enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        with trio.fail_after(3):
+        with trio.fail_after(3 if enable_timeout_assertion else 30000000):
            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
         if b % 100 == es_bulk_size and callback:
             callback(msg=f"Insert chunks: {b}/{len(chunks)}")
```
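The hunks in `graph_node_to_chunk`, `graph_edge_to_chunk`, and `set_graph` repeat the gate around short 3 s embedding/insert deadlines, though the fallback sentinels vary (30000000, 300000000, and 10000000000 seconds all mean "effectively never"). As a design note, a hedged sketch of a hypothetical `maybe_fail_after` helper (not part of this PR) that expresses "no deadline" with `contextlib.nullcontext` instead of a magic number:

```python
import contextlib
import os
import trio

def maybe_fail_after(seconds: float):
    # Hypothetical helper, not from the PR: enforce the deadline only when
    # ENABLE_TIMEOUT_ASSERTION is set; otherwise return a no-op context
    # instead of an arbitrarily large sentinel such as 30000000.
    if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
        return trio.fail_after(seconds)
    return contextlib.nullcontext()

async def encode_one():
    with maybe_fail_after(3):
        await trio.sleep(0)  # stands in for trio.to_thread.run_sync(...encode...)

trio.run(encode_one)
```

Either spelling behaves identically when the flag is set; the helper just avoids juggling three different sentinel constants.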
```diff
@@ -21,7 +21,7 @@ import sys
 import threading
 import time
 
-from api.utils.api_utils import timeout, is_strong_enough
+from api.utils.api_utils import timeout
 from api.utils.log_utils import init_root_logger, get_project_base_directory
 from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -478,8 +478,6 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
 
 @timeout(3600)
 async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
-    # Pressure test for GraphRAG task
-    await is_strong_enough(chat_mdl, embd_mdl)
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
@@ -553,7 +551,6 @@ async def do_handle_task(task):
     try:
         # bind embedding model
         embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
-        await is_strong_enough(None, embedding_model)
         vts, _ = embedding_model.encode(["ok"])
         vector_size = len(vts[0])
     except Exception as e:
@@ -568,7 +565,6 @@ async def do_handle_task(task):
     if task.get("task_type", "") == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await is_strong_enough(chat_model, None)
         # run RAPTOR
         async with kg_limiter:
             chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
@@ -580,7 +576,6 @@ async def do_handle_task(task):
         graphrag_conf = task["kb_parser_config"].get("graphrag", {})
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await is_strong_enough(chat_model, None)
         with_resolution = graphrag_conf.get("resolution", False)
         with_community = graphrag_conf.get("community", False)
         async with kg_limiter:
```