Feat: add fault-tolerant mechanism to RAPTOR (#11206)

### What problem does this PR solve?

Add a fault-tolerant mechanism to RAPTOR: LLM summarization calls are retried up to three times with a short backoff, a cluster that still fails is skipped with a warning instead of failing the whole task, and the run only aborts once a configurable error budget is exhausted (`RAPTOR_MAX_ERRORS`, default 3). The GraphRAG extractor's hard-coded error cap becomes configurable via `GRAPHRAG_MAX_ERRORS` as well.
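
At its core the change is an error-budget pattern: failures are counted on the instance and the run aborts only when the budget is spent. A minimal standalone sketch of that pattern (the `ErrorBudget` name and `record` helper are illustrative, not this PR's actual API):

```python
import logging


class ErrorBudget:
    """Count failures across units of work; abort once the budget is spent."""

    def __init__(self, max_errors: int = 3):
        self.max_errors = max(1, max_errors)  # mirrors the PR's max(1, max_errors) floor
        self.error_count = 0

    def record(self, exc: Exception) -> None:
        self.error_count += 1
        logging.warning("Skipping unit of work: %s", exc)
        if self.error_count >= self.max_errors:
            raise RuntimeError(f"Aborted after {self.error_count} errors") from exc


budget = ErrorBudget(max_errors=3)
for item in ("a", "b", "c"):
    try:
        raise ValueError(f"simulated failure on {item}")  # stand-in for a failing LLM call
    except ValueError as exc:
        budget.record(exc)  # the third failure raises RuntimeError and ends the run
```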

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Yongteng Lei
Date: 2025-11-13 18:48:07 +08:00
Committed by: GitHub
Parent: 70a0f081f6
Commit: 908450509f
4 changed files with 86 additions and 149 deletions

View File

@@ -3,15 +3,9 @@ import os
 import threading
 from typing import Any, Callable

-import requests
 from common.data_source.config import DocumentSource
 from common.data_source.google_util.constant import GOOGLE_SCOPES

-GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
-GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
-DEFAULT_DEVICE_INTERVAL = 5
-
 def _get_requested_scopes(source: DocumentSource) -> list[str]:
     """Return the scopes to request, honoring an optional override env var."""
@@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
     return result.get("value")


-def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
-    if "client_id" in credentials:
-        return credentials["client_id"], credentials.get("client_secret")
-    for key in ("installed", "web"):
-        if key in credentials and isinstance(credentials[key], dict):
-            nested = credentials[key]
-            if "client_id" not in nested:
-                break
-            return nested["client_id"], nested.get("client_secret")
-    raise ValueError("Provided Google OAuth credentials are missing client_id.")
-
-
-def start_device_authorization_flow(
-    credentials: dict[str, Any],
-    source: DocumentSource,
-) -> tuple[dict[str, Any], dict[str, Any]]:
-    client_id, client_secret = _extract_client_info(credentials)
-    data = {
-        "client_id": client_id,
-        "scope": " ".join(_get_requested_scopes(source)),
-    }
-    if client_secret:
-        data["client_secret"] = client_secret
-    resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
-    resp.raise_for_status()
-    payload = resp.json()
-    state = {
-        "client_id": client_id,
-        "client_secret": client_secret,
-        "device_code": payload.get("device_code"),
-        "interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
-    }
-    response_data = {
-        "user_code": payload.get("user_code"),
-        "verification_url": payload.get("verification_url") or payload.get("verification_uri"),
-        "verification_url_complete": payload.get("verification_url_complete")
-        or payload.get("verification_uri_complete"),
-        "expires_in": payload.get("expires_in"),
-        "interval": state["interval"],
-    }
-    return state, response_data
-
-
-def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
-    data = {
-        "client_id": state["client_id"],
-        "device_code": state["device_code"],
-        "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
-    }
-    if state.get("client_secret"):
-        data["client_secret"] = state["client_secret"]
-    resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
-    resp.raise_for_status()
-    return resp.json()
-
-
 def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
     """Launch the standard Google OAuth local-server flow to mint user tokens."""
     from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
@@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
     preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
     port = int(preferred_port) if preferred_port else 0
     timeout_secs = _get_oauth_timeout_secs()
-    timeout_message = (
-        f"Google OAuth verification timed out after {timeout_secs} seconds. "
-        "Close any pending consent windows and rerun the connector configuration to try again."
-    )
+    timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."

     print("Launching Google OAuth flow. A browser window should open shortly.")
     print("If it does not, copy the URL shown in the console into your browser manually.")
@@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
         instructions = [
             "Google rejected one or more of the requested OAuth scopes.",
             "Fix options:",
-            " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
-            " (Drive metadata + Admin Directory read scopes), then re-run the flow.",
+            " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
             " 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
-            " 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
-            " (be aware the connector may lose functionality).",
         ]
         raise RuntimeError("\n".join(instructions)) from warning
     raise
@@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
         client_config = {"web": credentials["web"]}
     if client_config is None:
-        raise ValueError(
-            "Provided Google OAuth credentials are missing both tokens and a client configuration."
-        )
+        raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
     return _run_local_server_flow(client_config, source)

View File

@@ -114,7 +114,7 @@ class Extractor:
     async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
         out_results = []
         error_count = 0
-        max_errors = 3
+        max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
         limiter = trio.Semaphore(max_concurrency)
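
Both this cap and the RAPTOR counterpart below (`RAPTOR_MAX_ERRORS`) are plain `int(os.environ.get(...))` reads, so a non-numeric value would raise `ValueError` at task time. A hedged variant with a defensive fallback, if that matters for a deployment (the `env_int` helper is hypothetical, not part of this PR):

```python
import os


def env_int(name: str, default: int) -> int:
    """Read an integer limit from the environment, falling back on bad input."""
    try:
        return int(os.environ.get(name, default))
    except (TypeError, ValueError):
        return default


max_errors = env_int("GRAPHRAG_MAX_ERRORS", 3)
```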

View File

@@ -15,27 +15,35 @@
 #
 import logging
 import re

-import umap
 import numpy as np
-from sklearn.mixture import GaussianMixture
 import trio
+import umap
+from sklearn.mixture import GaussianMixture

 from api.db.services.task_service import has_canceled
 from common.connection_utils import timeout
 from common.exceptions import TaskCanceledException
+from common.token_utils import truncate
 from graphrag.utils import (
-    get_llm_cache,
+    chat_limiter,
     get_embed_cache,
+    get_llm_cache,
     set_embed_cache,
     set_llm_cache,
-    chat_limiter,
 )
-from common.token_utils import truncate


 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
     def __init__(
-        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
+        self,
+        max_cluster,
+        llm_model,
+        embd_model,
+        prompt,
+        max_token=512,
+        threshold=0.1,
+        max_errors=3,
     ):
         self._max_cluster = max_cluster
         self._llm_model = llm_model
@@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._threshold = threshold
         self._prompt = prompt
         self._max_token = max_token
+        self._max_errors = max(1, max_errors)
+        self._error_count = 0

-    @timeout(60*20)
+    @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
-        response = await trio.to_thread.run_sync(
-            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
-        )
-        if response:
-            return response
-        response = await trio.to_thread.run_sync(
-            lambda: self._llm_model.chat(system, history, gen_conf)
-        )
-        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
-        if response.find("**ERROR**") >= 0:
-            raise Exception(response)
-        await trio.to_thread.run_sync(
-            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
-        )
-        return response
+        cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
+        if cached:
+            return cached
+        last_exc = None
+        for attempt in range(3):
+            try:
+                response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
+                response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
+                if response.find("**ERROR**") >= 0:
+                    raise Exception(response)
+                await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
+                return response
+            except Exception as exc:
+                last_exc = exc
+                logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
+                if attempt < 2:
+                    await trio.sleep(1 + attempt)
+        raise last_exc if last_exc else Exception("LLM chat failed without exception")

     @timeout(20)
     async def _embedding_encode(self, txt):
-        response = await trio.to_thread.run_sync(
-            lambda: get_embed_cache(self._embd_model.llm_name, txt)
-        )
+        response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
         if response is not None:
             return response
         embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
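
The rewritten `_chat` keeps the cache fast path, treats an `**ERROR**` marker in the response as a failure, and retries the blocking LLM call up to three times with a linearly growing sleep (1 s, then 2 s) before re-raising the last exception. The same retry shape, condensed into a generic helper (a sketch, assuming any blocking `call` that may raise):

```python
import trio


async def retry_call(call, attempts: int = 3):
    """Run a blocking callable in a worker thread, retrying with linear backoff."""
    last_exc = None
    for attempt in range(attempts):
        try:
            return await trio.to_thread.run_sync(call)
        except Exception as exc:
            last_exc = exc
            if attempt < attempts - 1:
                await trio.sleep(1 + attempt)  # sleep 1 s, then 2 s, ...
    raise last_exc
```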
@@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         n_clusters = np.arange(1, max_clusters)
         bics = []
         for n in n_clusters:
-
             if task_id:
                 if has_canceled(task_id):
                     logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         layers = [(0, len(chunks))]
         start, end = 0, len(chunks)

-        @timeout(60*20)
+        @timeout(60 * 20)
         async def summarize(ck_idx: list[int]):
             nonlocal chunks
@@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     raise TaskCanceledException(f"Task {task_id} was cancelled")

             texts = [chunks[i][0] for i in ck_idx]
-            len_per_chunk = int(
-                (self._llm_model.max_length - self._max_token) / len(texts)
-            )
-            cluster_content = "\n".join(
-                [truncate(t, max(1, len_per_chunk)) for t in texts]
-            )
-            async with chat_limiter:
-                if task_id and has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
-
-                cnt = await self._chat(
-                    "You're a helpful assistant.",
-                    [
-                        {
-                            "role": "user",
-                            "content": self._prompt.format(
-                                cluster_content=cluster_content
-                            ),
-                        }
-                    ],
-                    {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
-                )
-            cnt = re.sub(
-                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
-                "",
-                cnt,
-            )
-            logging.debug(f"SUM: {cnt}")
-
-            if task_id and has_canceled(task_id):
-                logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
-                raise TaskCanceledException(f"Task {task_id} was cancelled")
-
-            embds = await self._embedding_encode(cnt)
-            chunks.append((cnt, embds))
+            len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
+            cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
+            try:
+                async with chat_limiter:
+                    if task_id and has_canceled(task_id):
+                        logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
+                        raise TaskCanceledException(f"Task {task_id} was cancelled")
+                    cnt = await self._chat(
+                        "You're a helpful assistant.",
+                        [
+                            {
+                                "role": "user",
+                                "content": self._prompt.format(cluster_content=cluster_content),
+                            }
+                        ],
+                        {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
+                    )
+                cnt = re.sub(
+                    "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+                    "",
+                    cnt,
+                )
+                logging.debug(f"SUM: {cnt}")
+                if task_id and has_canceled(task_id):
+                    logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
+                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+                embds = await self._embedding_encode(cnt)
+                chunks.append((cnt, embds))
+            except TaskCanceledException:
+                raise
+            except Exception as exc:
+                self._error_count += 1
+                warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
+                logging.warning(warn_msg)
+                if callback:
+                    callback(msg=warn_msg)
+                if self._error_count >= self._max_errors:
+                    raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc

         labels = []
         while end - start > 1:
             if task_id:
                 if has_canceled(task_id):
                     logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
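
Two details in this hunk are worth noting: `TaskCanceledException` is re-raised untouched so cancellation is never absorbed by the new `except Exception` arm, and `_error_count` lives on the instance, so every `summarize` task started under the same trio nursery draws on one shared budget. A self-contained toy model of that sharing (not the PR's code; names are illustrative):

```python
import logging

import trio


class SharedBudgetDemo:
    """Toy model of the new summarize() guard: concurrent tasks share one counter."""

    def __init__(self, max_errors: int = 3):
        self._max_errors = max(1, max_errors)
        self._error_count = 0

    async def work(self, idx: int) -> None:
        try:
            raise ValueError(f"simulated failure in task {idx}")  # stand-in for LLM + embedding work
        except ValueError as exc:
            self._error_count += 1  # one budget shared by every task in the nursery
            logging.warning("skip task %d: %s", idx, exc)
            if self._error_count >= self._max_errors:
                raise RuntimeError(f"aborted after {self._error_count} errors") from exc


async def main():
    demo = SharedBudgetDemo(max_errors=3)
    async with trio.open_nursery() as nursery:
        for i in range(4):
            nursery.start_soon(demo.work, i)


trio.run(main)  # fails once the shared count reaches 3; earlier failures are merely logged
```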
@@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             if len(embeddings) == 2:
                 await summarize([start, start + 1])
                 if callback:
-                    callback(
-                        msg="Cluster one layer: {} -> {}".format(
-                            end - start, len(chunks) - end
-                        )
-                    )
+                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                 labels.extend([0, 0])
                 layers.append((end, len(chunks)))
                 start = end
@@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 nursery.start_soon(summarize, ck_idx)

-            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
-                len(chunks) - end, n_clusters
-            )
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
             labels.extend(lbls)
             layers.append((end, len(chunks)))
             if callback:
-                callback(
-                    msg="Cluster one layer: {} -> {}".format(
-                        end - start, len(chunks) - end
-                    )
-                )
+                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
             start = end
             end = len(chunks)

View File

@@ -649,6 +649,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
     res = []
     tk_count = 0
+    max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
+
     async def generate(chunks, did):
         nonlocal tk_count, res
         raptor = Raptor(
@@ -658,6 +660,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
             raptor_config["prompt"],
             raptor_config["max_token"],
             raptor_config["threshold"],
+            max_errors=max_errors,
         )
         original_length = len(chunks)
         chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
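
With both knobs wired up, tolerance can be tuned per deployment without code changes. Since the values are read via `os.environ` when the task runs, any standard way of setting environment variables works; illustratively, in Python, before the task executor picks up the work:

```python
import os

# Illustrative values; both variables default to 3 when unset.
os.environ["RAPTOR_MAX_ERRORS"] = "5"      # abort a RAPTOR run on the fifth failed cluster
os.environ["GRAPHRAG_MAX_ERRORS"] = "10"   # looser budget for GraphRAG chunk extraction
```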