Compare commits


2 Commits

Author SHA1 Message Date
e8f1a245a6 Feat: update check_embedding api (#11254)
### What problem does this PR solve?
Related PR: #10854
Change: update the check_embedding API.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-13 18:48:25 +08:00
908450509f Feat: add fault-tolerant mechanism to RAPTOR (#11206)
### What problem does this PR solve?

Add a fault-tolerant mechanism to RAPTOR.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-13 18:48:07 +08:00
5 changed files with 113 additions and 157 deletions

View File

@@ -16,6 +16,7 @@
 import json
 import logging
 import random
+import re
 from flask import request
 from flask_login import login_required, current_user
@@ -847,8 +848,13 @@ def check_embedding():
                 "position_int": full_doc.get("position_int"),
                 "top_int": full_doc.get("top_int"),
                 "content_with_weight": full_doc.get("content_with_weight") or "",
+                "question_kwd": full_doc.get("question_kwd") or []
             })
         return out
+
+    def _clean(s: str) -> str:
+        s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
+        return s if s else "None"
     req = request.json
     kb_id = req.get("kb_id", "")
     embd_id = req.get("embd_id", "")
@@ -861,8 +867,10 @@ def check_embedding():
     results, eff_sims = [], []
     for ck in samples:
-        txt = (ck.get("content_with_weight") or "").strip()
-        if not txt:
+        title = ck.get("doc_name") or "Title"
+        txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
+        txt_in = _clean(txt_in)
+        if not txt_in:
             results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
             continue
@@ -871,8 +879,16 @@ def check_embedding():
             continue
         try:
-            qv, _ = emb_mdl.encode_queries(txt)
-            sim = _cos_sim(qv, ck["vector"])
+            v, _ = emb_mdl.encode([title, txt_in])
+            sim_content = _cos_sim(v[1], ck["vector"])
+            title_w = 0.1
+            qv_mix = title_w * v[0] + (1 - title_w) * v[1]
+            sim_mix = _cos_sim(qv_mix, ck["vector"])
+            sim = sim_content
+            mode = "content_only"
+            if sim_mix > sim:
+                sim = sim_mix
+                mode = "title+content"
         except Exception:
             return get_error_data_result(message="embedding failure")
@@ -894,8 +910,9 @@ def check_embedding():
         "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
         "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
         "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
+        "match_mode": mode,
     }
-    if summary["avg_cos_sim"] > 0.99:
+    if summary["avg_cos_sim"] > 0.9:
         return get_json_result(data={"summary": summary, "results": results})
     return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
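The updated endpoint embeds the document title alongside the chunk text and keeps whichever similarity scores higher. A minimal standalone sketch of that scoring step, using toy vectors and a local cosine helper as stand-ins for the project's emb_mdl and _cos_sim (the stand-ins are assumptions, not the actual API):

```python
import numpy as np

def cos_sim(a, b):
    # Cosine similarity between two 1-D vectors.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
title_vec, content_vec = rng.random(8), rng.random(8)  # stand-ins for encode([title, txt_in])
stored_vec = rng.random(8)                             # stand-in for ck["vector"]

sim_content = cos_sim(content_vec, stored_vec)

# Blend 10% of the title vector into the query, as in the diff above,
# and keep the higher score, recording which mode won (match_mode).
title_w = 0.1
mixed = title_w * title_vec + (1 - title_w) * content_vec
sim_mixed = cos_sim(mixed, stored_vec)

sim, mode = max((sim_content, "content_only"), (sim_mixed, "title+content"))
print(mode, round(sim, 6))
```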

View File

@@ -3,15 +3,9 @@ import os
 import threading
 from typing import Any, Callable

-import requests
-
 from common.data_source.config import DocumentSource
 from common.data_source.google_util.constant import GOOGLE_SCOPES

-GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
-GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
-DEFAULT_DEVICE_INTERVAL = 5
-
 def _get_requested_scopes(source: DocumentSource) -> list[str]:
     """Return the scopes to request, honoring an optional override env var."""
@@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
     return result.get("value")

-def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
-    if "client_id" in credentials:
-        return credentials["client_id"], credentials.get("client_secret")
-    for key in ("installed", "web"):
-        if key in credentials and isinstance(credentials[key], dict):
-            nested = credentials[key]
-            if "client_id" not in nested:
-                break
-            return nested["client_id"], nested.get("client_secret")
-    raise ValueError("Provided Google OAuth credentials are missing client_id.")
-
-def start_device_authorization_flow(
-    credentials: dict[str, Any],
-    source: DocumentSource,
-) -> tuple[dict[str, Any], dict[str, Any]]:
-    client_id, client_secret = _extract_client_info(credentials)
-    data = {
-        "client_id": client_id,
-        "scope": " ".join(_get_requested_scopes(source)),
-    }
-    if client_secret:
-        data["client_secret"] = client_secret
-    resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
-    resp.raise_for_status()
-    payload = resp.json()
-    state = {
-        "client_id": client_id,
-        "client_secret": client_secret,
-        "device_code": payload.get("device_code"),
-        "interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
-    }
-    response_data = {
-        "user_code": payload.get("user_code"),
-        "verification_url": payload.get("verification_url") or payload.get("verification_uri"),
-        "verification_url_complete": payload.get("verification_url_complete")
-        or payload.get("verification_uri_complete"),
-        "expires_in": payload.get("expires_in"),
-        "interval": state["interval"],
-    }
-    return state, response_data
-
-def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
-    data = {
-        "client_id": state["client_id"],
-        "device_code": state["device_code"],
-        "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
-    }
-    if state.get("client_secret"):
-        data["client_secret"] = state["client_secret"]
-    resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
-    resp.raise_for_status()
-    return resp.json()

 def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
     """Launch the standard Google OAuth local-server flow to mint user tokens."""
     from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
@@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
     preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
     port = int(preferred_port) if preferred_port else 0
     timeout_secs = _get_oauth_timeout_secs()
-    timeout_message = (
-        f"Google OAuth verification timed out after {timeout_secs} seconds. "
-        "Close any pending consent windows and rerun the connector configuration to try again."
-    )
+    timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."

     print("Launching Google OAuth flow. A browser window should open shortly.")
     print("If it does not, copy the URL shown in the console into your browser manually.")
@@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
         instructions = [
             "Google rejected one or more of the requested OAuth scopes.",
             "Fix options:",
-            " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
-            " (Drive metadata + Admin Directory read scopes), then re-run the flow.",
+            " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
             " 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
-            " 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
-            " (be aware the connector may lose functionality).",
         ]
         raise RuntimeError("\n".join(instructions)) from warning
     raise
@@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
         client_config = {"web": credentials["web"]}
     if client_config is None:
-        raise ValueError(
-            "Provided Google OAuth credentials are missing both tokens and a client configuration."
-        )
+        raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
     return _run_local_server_flow(client_config, source)
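With the device-code flow removed, the connector now mints user tokens only through the local-server flow retained above. A minimal sketch of that flow with google_auth_oauthlib; the client_config values and the scope are placeholders, not the connector's real configuration:

```python
from google_auth_oauthlib.flow import InstalledAppFlow

# Placeholder client config; real values come from the OAuth client JSON
# downloaded from Google Cloud Console.
client_config = {
    "installed": {
        "client_id": "YOUR_CLIENT_ID",
        "client_secret": "YOUR_CLIENT_SECRET",
        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
        "token_uri": "https://oauth2.googleapis.com/token",
        "redirect_uris": ["http://localhost"],
    }
}
scopes = ["https://www.googleapis.com/auth/drive.metadata.readonly"]  # example scope

flow = InstalledAppFlow.from_client_config(client_config, scopes)
# Starts a loopback HTTP server, opens the consent page, and blocks until
# the user approves (or the surrounding timeout fires, as in the file above).
creds = flow.run_local_server(port=0)
print(bool(creds.token))
```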

View File

@@ -114,7 +114,7 @@ class Extractor:
     async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
         out_results = []
         error_count = 0
-        max_errors = 3
+        max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
         limiter = trio.Semaphore(max_concurrency)
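The error cap in extract_all is now read from the environment rather than hard-coded. A one-liner sketch of the pattern (the GRAPHRAG_MAX_ERRORS name comes from the diff; note that a non-integer value would raise ValueError):

```python
import os

# Defaults to 3 when GRAPHRAG_MAX_ERRORS is unset; run with e.g.
# GRAPHRAG_MAX_ERRORS=5 to tolerate more extraction failures.
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
print(max_errors)
```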

View File

@@ -15,27 +15,35 @@
 #
 import logging
 import re
-import umap
 import numpy as np
-from sklearn.mixture import GaussianMixture
 import trio
+import umap
+from sklearn.mixture import GaussianMixture
 from api.db.services.task_service import has_canceled
 from common.connection_utils import timeout
 from common.exceptions import TaskCanceledException
+from common.token_utils import truncate
 from graphrag.utils import (
-    get_llm_cache,
+    chat_limiter,
     get_embed_cache,
+    get_llm_cache,
     set_embed_cache,
     set_llm_cache,
-    chat_limiter,
 )
-from common.token_utils import truncate

 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
     def __init__(
-        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
+        self,
+        max_cluster,
+        llm_model,
+        embd_model,
+        prompt,
+        max_token=512,
+        threshold=0.1,
+        max_errors=3,
     ):
         self._max_cluster = max_cluster
         self._llm_model = llm_model
@@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._threshold = threshold
         self._prompt = prompt
         self._max_token = max_token
+        self._max_errors = max(1, max_errors)
+        self._error_count = 0

-    @timeout(60*20)
+    @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
-        response = await trio.to_thread.run_sync(
-            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
-        )
-        if response:
-            return response
-        response = await trio.to_thread.run_sync(
-            lambda: self._llm_model.chat(system, history, gen_conf)
-        )
-        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
-        if response.find("**ERROR**") >= 0:
-            raise Exception(response)
-        await trio.to_thread.run_sync(
-            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
-        )
-        return response
+        cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
+        if cached:
+            return cached
+        last_exc = None
+        for attempt in range(3):
+            try:
+                response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
+                response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
+                if response.find("**ERROR**") >= 0:
+                    raise Exception(response)
+                await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
+                return response
+            except Exception as exc:
+                last_exc = exc
+                logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
+                if attempt < 2:
+                    await trio.sleep(1 + attempt)
+        raise last_exc if last_exc else Exception("LLM chat failed without exception")

     @timeout(20)
     async def _embedding_encode(self, txt):
-        response = await trio.to_thread.run_sync(
-            lambda: get_embed_cache(self._embd_model.llm_name, txt)
-        )
+        response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
         if response is not None:
             return response
         embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         n_clusters = np.arange(1, max_clusters)
         bics = []
         for n in n_clusters:
-
             if task_id:
                 if has_canceled(task_id):
                     logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         layers = [(0, len(chunks))]
         start, end = 0, len(chunks)

-        @timeout(60*20)
+        @timeout(60 * 20)
         async def summarize(ck_idx: list[int]):
             nonlocal chunks
@@ -111,14 +122,10 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     raise TaskCanceledException(f"Task {task_id} was cancelled")
             texts = [chunks[i][0] for i in ck_idx]
-            len_per_chunk = int(
-                (self._llm_model.max_length - self._max_token) / len(texts)
-            )
-            cluster_content = "\n".join(
-                [truncate(t, max(1, len_per_chunk)) for t in texts]
-            )
+            len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
+            cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
+            try:
                 async with chat_limiter:
                     if task_id and has_canceled(task_id):
                         logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
                         raise TaskCanceledException(f"Task {task_id} was cancelled")
@@ -128,9 +135,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     [
                         {
                             "role": "user",
-                            "content": self._prompt.format(
-                                cluster_content=cluster_content
-                            ),
+                            "content": self._prompt.format(cluster_content=cluster_content),
                         }
                     ],
                     {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
@@ -148,10 +153,19 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 embds = await self._embedding_encode(cnt)
                 chunks.append((cnt, embds))
+            except TaskCanceledException:
+                raise
+            except Exception as exc:
+                self._error_count += 1
+                warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
+                logging.warning(warn_msg)
+                if callback:
+                    callback(msg=warn_msg)
+                if self._error_count >= self._max_errors:
+                    raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc

         labels = []
         while end - start > 1:
             if task_id:
                 if has_canceled(task_id):
                     logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             if len(embeddings) == 2:
                 await summarize([start, start + 1])
                 if callback:
-                    callback(
-                        msg="Cluster one layer: {} -> {}".format(
-                            end - start, len(chunks) - end
-                        )
-                    )
+                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                 labels.extend([0, 0])
                 layers.append((end, len(chunks)))
                 start = end
@@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     nursery.start_soon(summarize, ck_idx)
-            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
-                len(chunks) - end, n_clusters
-            )
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
             labels.extend(lbls)
             layers.append((end, len(chunks)))
             if callback:
-                callback(
-                    msg="Cluster one layer: {} -> {}".format(
-                        end - start, len(chunks) - end
-                    )
-                )
+                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
             start = end
             end = len(chunks)
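The fault tolerance added to RAPTOR has two layers: each LLM call is retried up to three times with linear backoff, and failed clusters are skipped until a shared error budget runs out, at which point the whole run aborts. A synchronous standalone sketch of the same pattern, with hypothetical names in place of the trio-based code above:

```python
import logging
import time

def chat_with_retry(call, attempts=3):
    """Retry a flaky zero-argument callable with linear backoff."""
    last_exc = None
    for attempt in range(attempts):
        try:
            return call()
        except Exception as exc:
            last_exc = exc
            logging.warning("LLM call failed on attempt %d/%d: %s", attempt + 1, attempts, exc)
            if attempt < attempts - 1:
                time.sleep(1 + attempt)  # 1s, then 2s, as in the diff
    raise last_exc

class ErrorBudget:
    """Skip individual failures, but abort once too many accumulate."""

    def __init__(self, max_errors=3):
        self.max_errors = max(1, max_errors)
        self.count = 0

    def record(self, exc):
        self.count += 1
        logging.warning("skipping cluster due to error: %s", exc)
        if self.count >= self.max_errors:
            raise RuntimeError(f"aborted after {self.count} errors") from exc
```

Cancellation is deliberately kept outside the budget: the diff re-raises TaskCanceledException before the generic handler, so a user cancel still stops the run immediately.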

View File

@@ -442,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     tk_count = 0
     if len(tts) == len(cnts):
         vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
-        tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
+        tts = np.tile(vts[0], (len(cnts), 1))
         tk_count += c

     @timeout(60)
@@ -465,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     if not filename_embd_weight:
         filename_embd_weight = 0.1
     title_w = float(filename_embd_weight)
-    vects = (title_w * tts + (1 - title_w) *
-             cnts) if len(tts) == len(cnts) else cnts
+    if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
+        vects = title_w * tts + (1 - title_w) * cnts
+    else:
+        vects = cnts

     assert len(vects) == len(docs)
     vector_size = 0
@@ -649,6 +651,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
     res = []
     tk_count = 0
+    max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
+
     async def generate(chunks, did):
         nonlocal tk_count, res
         raptor = Raptor(
@@ -658,6 +662,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
             raptor_config["prompt"],
             raptor_config["max_token"],
             raptor_config["threshold"],
+            max_errors=max_errors,
         )
         original_length = len(chunks)
         chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
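The np.tile change in the embedding hunk is worth spelling out: np.concatenate over 1-D title vectors produced one long flat vector rather than a row per chunk, so the guarded title/content blend could never line up with cnts. A small shape demonstration with toy dimensions (the 1-D title-vector shape is an assumption about what mdl.encode returns):

```python
import numpy as np

dim, n_chunks = 4, 3
title_vec = np.ones(dim)          # the single encoded title, shape (dim,)
cnts = np.zeros((n_chunks, dim))  # one content embedding per chunk

# Old behaviour: a flat (n_chunks * dim,) array, shape-incompatible with cnts.
flat = np.concatenate([title_vec for _ in range(n_chunks)], axis=0)
print(flat.shape)  # (12,)

# New behaviour: the title repeated row-wise, shape (n_chunks, dim).
tts = np.tile(title_vec, (n_chunks, 1))
print(tts.shape)   # (3, 4)

# The guarded blend from the diff is now well-defined.
title_w = 0.1
vects = title_w * tts + (1 - title_w) * cnts if tts.shape == cnts.shape else cnts
print(vects.shape)  # (3, 4)
```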