Compare commits

...

2 Commits

Author SHA1 Message Date
e8f1a245a6 Feat: update check_embedding api (#11254)
### What problem does this PR solve?

Related PR: #10854

Change: update the check_embedding API.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-13 18:48:25 +08:00
908450509f Feat: add fault-tolerant mechanism to RAPTOR (#11206)
### What problem does this PR solve?

Add a fault-tolerant mechanism to RAPTOR.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-13 18:48:07 +08:00
5 changed files with 113 additions and 157 deletions

View File

@@ -16,6 +16,7 @@
import json
import logging
import random
import re
from flask import request
from flask_login import login_required, current_user
@@ -847,8 +848,13 @@ def check_embedding():
"position_int": full_doc.get("position_int"),
"top_int": full_doc.get("top_int"),
"content_with_weight": full_doc.get("content_with_weight") or "",
"question_kwd": full_doc.get("question_kwd") or []
})
return out
def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
req = request.json
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
@@ -861,8 +867,10 @@ def check_embedding():
results, eff_sims = [], []
for ck in samples:
txt = (ck.get("content_with_weight") or "").strip()
if not txt:
title = ck.get("doc_name") or "Title"
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
txt_in = _clean(txt_in)
if not txt_in:
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
continue
@@ -871,8 +879,16 @@ def check_embedding():
continue
try:
qv, _ = emb_mdl.encode_queries(txt)
sim = _cos_sim(qv, ck["vector"])
v, _ = emb_mdl.encode([title, txt_in])
sim_content = _cos_sim(v[1], ck["vector"])
title_w = 0.1
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
sim_mix = _cos_sim(qv_mix, ck["vector"])
sim = sim_content
mode = "content_only"
if sim_mix > sim:
sim = sim_mix
mode = "title+content"
except Exception:
return get_error_data_result(message="embedding failure")
@@ -894,8 +910,9 @@ def check_embedding():
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
"match_mode": mode,
}
if summary["avg_cos_sim"] > 0.99:
if summary["avg_cos_sim"] > 0.9:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
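
For reference, here is a minimal standalone sketch of the mixed-vector scoring this hunk introduces, assuming `_cos_sim` is an ordinary cosine similarity over numpy vectors and that the model's `encode` returns one embedding per input text (the helper names below are illustrative, not the project's API):

```python
import numpy as np

def cos_sim(a, b):
    # Plain cosine similarity between two 1-D vectors.
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denom) if denom else 0.0

def score_chunk(title_vec, content_vec, stored_vec, title_w=0.1):
    # Score the stored chunk vector against the content embedding alone and
    # against a title-weighted mixture, keeping whichever similarity is higher.
    sim_content = cos_sim(content_vec, stored_vec)
    mixed = title_w * np.asarray(title_vec) + (1 - title_w) * np.asarray(content_vec)
    sim_mix = cos_sim(mixed, stored_vec)
    if sim_mix > sim_content:
        return sim_mix, "title+content"
    return sim_content, "content_only"
```

With the default weight of 0.1 the content embedding still dominates the mixture, which is presumably why the acceptance threshold is relaxed from 0.99 to 0.9 in the same change.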

View File

@@ -3,15 +3,9 @@ import os
import threading
from typing import Any, Callable
import requests
from common.data_source.config import DocumentSource
from common.data_source.google_util.constant import GOOGLE_SCOPES
GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
DEFAULT_DEVICE_INTERVAL = 5
def _get_requested_scopes(source: DocumentSource) -> list[str]:
"""Return the scopes to request, honoring an optional override env var."""
@@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
return result.get("value")
def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
if "client_id" in credentials:
return credentials["client_id"], credentials.get("client_secret")
for key in ("installed", "web"):
if key in credentials and isinstance(credentials[key], dict):
nested = credentials[key]
if "client_id" not in nested:
break
return nested["client_id"], nested.get("client_secret")
raise ValueError("Provided Google OAuth credentials are missing client_id.")
def start_device_authorization_flow(
credentials: dict[str, Any],
source: DocumentSource,
) -> tuple[dict[str, Any], dict[str, Any]]:
client_id, client_secret = _extract_client_info(credentials)
data = {
"client_id": client_id,
"scope": " ".join(_get_requested_scopes(source)),
}
if client_secret:
data["client_secret"] = client_secret
resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
resp.raise_for_status()
payload = resp.json()
state = {
"client_id": client_id,
"client_secret": client_secret,
"device_code": payload.get("device_code"),
"interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
}
response_data = {
"user_code": payload.get("user_code"),
"verification_url": payload.get("verification_url") or payload.get("verification_uri"),
"verification_url_complete": payload.get("verification_url_complete")
or payload.get("verification_uri_complete"),
"expires_in": payload.get("expires_in"),
"interval": state["interval"],
}
return state, response_data
def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
data = {
"client_id": state["client_id"],
"device_code": state["device_code"],
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
}
if state.get("client_secret"):
data["client_secret"] = state["client_secret"]
resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
resp.raise_for_status()
return resp.json()
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
@@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
port = int(preferred_port) if preferred_port else 0
timeout_secs = _get_oauth_timeout_secs()
timeout_message = (
f"Google OAuth verification timed out after {timeout_secs} seconds. "
"Close any pending consent windows and rerun the connector configuration to try again."
)
timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
print("Launching Google OAuth flow. A browser window should open shortly.")
print("If it does not, copy the URL shown in the console into your browser manually.")
@@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
instructions = [
"Google rejected one or more of the requested OAuth scopes.",
"Fix options:",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
" (be aware the connector may lose functionality).",
]
raise RuntimeError("\n".join(instructions)) from warning
raise
@@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
client_config = {"web": credentials["web"]}
if client_config is None:
raise ValueError(
"Provided Google OAuth credentials are missing both tokens and a client configuration."
)
raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
return _run_local_server_flow(client_config, source)
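
The error hints above mention a GOOGLE_OAUTH_SCOPE_OVERRIDE variable, and `_get_requested_scopes` is documented as honoring an optional override. A minimal sketch of that idea, with an illustrative default scope list standing in for the module's GOOGLE_SCOPES constant:

```python
import os

# Illustrative default; the real module imports GOOGLE_SCOPES from its constants.
DEFAULT_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]

def requested_scopes(env_var="GOOGLE_OAUTH_SCOPE_OVERRIDE"):
    # Use the comma-separated override when it is set, otherwise fall back to defaults.
    override = os.environ.get(env_var, "").strip()
    if override:
        return [s.strip() for s in override.split(",") if s.strip()]
    return list(DEFAULT_SCOPES)
```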

View File

@@ -114,7 +114,7 @@ class Extractor:
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = []
error_count = 0
max_errors = 3
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency)
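
One caveat with reading the ceiling this way: `int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))` raises ValueError if the variable holds a non-numeric value. A slightly more defensive sketch (illustrative only; the hunk above uses the direct form):

```python
import os

def error_ceiling(env_var="GRAPHRAG_MAX_ERRORS", default=3):
    # Fall back to the default when the variable is unset or not a valid integer,
    # and never allow a ceiling below 1.
    try:
        value = int(os.environ.get(env_var, default))
    except (TypeError, ValueError):
        value = default
    return max(1, value)
```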

View File

@@ -15,27 +15,35 @@
#
import logging
import re
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio
import umap
from sklearn.mixture import GaussianMixture
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.exceptions import TaskCanceledException
from common.token_utils import truncate
from graphrag.utils import (
get_llm_cache,
chat_limiter,
get_embed_cache,
get_llm_cache,
set_embed_cache,
set_llm_cache,
chat_limiter,
)
from common.token_utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __init__(
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
self,
max_cluster,
llm_model,
embd_model,
prompt,
max_token=512,
threshold=0.1,
max_errors=3,
):
self._max_cluster = max_cluster
self._llm_model = llm_model
@@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._threshold = threshold
self._prompt = prompt
self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
@timeout(60*20)
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
response = await trio.to_thread.run_sync(
lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
)
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
if cached:
return cached
if response:
return response
response = await trio.to_thread.run_sync(
lambda: self._llm_model.chat(system, history, gen_conf)
)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(
lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
)
return response
last_exc = None
for attempt in range(3):
try:
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
return response
except Exception as exc:
last_exc = exc
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
if attempt < 2:
await trio.sleep(1 + attempt)
raise last_exc if last_exc else Exception("LLM chat failed without exception")
@timeout(20)
async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync(
lambda: get_embed_cache(self._embd_model.llm_name, txt)
)
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
if response is not None:
return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_clusters = np.arange(1, max_clusters)
bics = []
for n in n_clusters:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
@timeout(60*20)
@timeout(60 * 20)
async def summarize(ck_idx: list[int]):
nonlocal chunks
@@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
raise TaskCanceledException(f"Task {task_id} was cancelled")
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int(
(self._llm_model.max_length - self._max_token) / len(texts)
)
cluster_content = "\n".join(
[truncate(t, max(1, len_per_chunk)) for t in texts]
)
async with chat_limiter:
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(
cluster_content=cluster_content
),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
labels = []
while end - start > 1:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
if len(embeddings) == 2:
await summarize([start, start + 1])
if callback:
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
labels.extend([0, 0])
layers.append((end, len(chunks)))
start = end
@@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
nursery.start_soon(summarize, ck_idx)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(
len(chunks) - end, n_clusters
)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls)
layers.append((end, len(chunks)))
if callback:
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
start = end
end = len(chunks)
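
Taken together, the fault-tolerant mechanism is a bounded retry around each LLM call plus a per-run error budget for whole clusters. A condensed, framework-free sketch of the same pattern (using time.sleep instead of trio and plain callables instead of the project's model wrappers):

```python
import logging
import time

class ErrorBudget:
    """Skip individual failures until a ceiling is reached, then abort the run."""

    def __init__(self, max_errors=3):
        self.max_errors = max(1, max_errors)
        self.count = 0

    def record(self, exc):
        self.count += 1
        logging.warning("skipped cluster due to error: %s", exc)
        if self.count >= self.max_errors:
            raise RuntimeError(f"aborted after {self.count} errors") from exc

def call_with_retries(fn, attempts=3):
    # Retry a flaky call with a small linear backoff, re-raising the last error.
    last_exc = None
    for attempt in range(attempts):
        try:
            return fn()
        except Exception as exc:
            last_exc = exc
            logging.warning("attempt %d/%d failed: %s", attempt + 1, attempts, exc)
            if attempt < attempts - 1:
                time.sleep(1 + attempt)
    raise last_exc
```

In the diff above, the retry lives inside `_chat`, while the error budget lives in the `summarize` wrapper, so a few failed clusters degrade the tree instead of failing the whole task.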

View File

@@ -442,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c
@timeout(60)
@@ -465,8 +465,10 @@
if not filename_embd_weight:
filename_embd_weight = 0.1
title_w = float(filename_embd_weight)
vects = (title_w * tts + (1 - title_w) *
cnts) if len(tts) == len(cnts) else cnts
if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
vects = title_w * tts + (1 - title_w) * cnts
else:
vects = cnts
assert len(vects) == len(docs)
vector_size = 0
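
The np.tile change and the shape guard above address the same issue: the single title (filename) embedding has to become a 2-D array with one row per content chunk before it can be blended. A small sketch of that blend (standalone, with toy data rather than the task executor's model calls):

```python
import numpy as np

def blend_title_and_content(title_vec, content_vecs, title_w=0.1):
    # Repeat the single title embedding once per chunk. np.tile keeps the result
    # 2-D, unlike concatenating copies of a 1-D row, which flattens to one long vector.
    content_vecs = np.asarray(content_vecs, dtype=float)
    tts = np.tile(np.asarray(title_vec, dtype=float), (len(content_vecs), 1))
    if tts.ndim == 2 and tts.shape == content_vecs.shape:
        return title_w * tts + (1 - title_w) * content_vecs
    return content_vecs  # fall back to content-only vectors on a shape mismatch

# Example: one 4-dim title vector blended into three chunk vectors.
print(blend_title_and_content(np.ones(4), np.zeros((3, 4))).shape)  # (3, 4)
```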
@@ -649,6 +651,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
res = []
tk_count = 0
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
async def generate(chunks, did):
nonlocal tk_count, res
raptor = Raptor(
@@ -658,6 +662,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["prompt"],
raptor_config["max_token"],
raptor_config["threshold"],
max_errors=max_errors,
)
original_length = len(chunks)
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])