From 2114e966d8296cccb528a6f24cc13e8c1312e93f Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 15 Aug 2025 10:05:01 +0800 Subject: [PATCH] Feat: add citation option to agent and enlarge the timeouts. (#9484) ### What problem does this PR solve? #9422 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- agent/component/agent_with_tools.py | 2 +- agent/component/llm.py | 2 +- graphrag/entity_resolution.py | 4 +-- .../general/community_reports_extractor.py | 2 +- graphrag/general/index.py | 31 ++++++++++--------- rag/svr/task_executor.py | 2 +- 6 files changed, 22 insertions(+), 21 deletions(-) diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index f656af0e3..f0369c0a6 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -244,7 +244,7 @@ class Agent(LLM, ToolBase): def complete(): nonlocal hist - need2cite = self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0 + need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0 cited = False if hist[0]["role"] == "system" and need2cite: if len(hist) < 7: diff --git a/agent/component/llm.py b/agent/component/llm.py index 5e10220a4..963a7e9f0 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -145,7 +145,7 @@ class LLM(ComponentBase): prompt = self.string_format(prompt, args) for m in msg: m["content"] = self.string_format(m["content"], args) - if self._canvas.get_reference()["chunks"]: + if self._param.cite and self._canvas.get_reference()["chunks"]: prompt += citation_prompt() return prompt, msg diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index 8892bb2f6..c324492fe 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -106,7 +106,7 @@ class EntityResolution(Extractor): nonlocal remain_candidates_to_resolve, callback async with semaphore: try: - with trio.move_on_after(180) as cancel_scope: + with trio.move_on_after(280) as cancel_scope: await self._resolve_candidate(candidate_batch, result_set, result_lock) remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ") @@ -169,7 +169,7 @@ class EntityResolution(Extractor): logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}") async with chat_limiter: try: - with trio.move_on_after(120) as cancel_scope: + with trio.move_on_after(240) as cancel_scope: response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}) if cancel_scope.cancelled_caught: logging.warning("_resolve_candidate._chat timeout, skipping...") diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index 6a292b98c..32b0d6a0f 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -92,7 +92,7 @@ class CommunityReportsExtractor(Extractor): text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) async with chat_limiter: try: - with trio.move_on_after(80) as cancel_scope: + with trio.move_on_after(180) as cancel_scope: response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}) if cancel_scope.cancelled_caught: logging.warning("extract_community_report._chat timeout, skipping...") diff --git a/graphrag/general/index.py b/graphrag/general/index.py index ca88a292b..9adfab076 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -57,20 +57,22 @@ async def run_graphrag( ): chunks.append(d["content_with_weight"]) - subgraph = await generate_subgraph( - LightKGExt - if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" - else GeneralKGExt, - tenant_id, - kb_id, - doc_id, - chunks, - language, - row["kb_parser_config"]["graphrag"].get("entity_types", []), - chat_model, - embedding_model, - callback, - ) + with trio.fail_after(len(chunks)*60): + subgraph = await generate_subgraph( + LightKGExt + if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" + else GeneralKGExt, + tenant_id, + kb_id, + doc_id, + chunks, + language, + row["kb_parser_config"]["graphrag"].get("entity_types", []), + chat_model, + embedding_model, + callback, + ) + if not subgraph: return @@ -125,7 +127,6 @@ async def run_graphrag( return -@timeout(60*60, 1) async def generate_subgraph( extractor: Extractor, tenant_id: str, diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 649c7e95e..4477a825f 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -520,7 +520,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): return res, tk_count -@timeout(60*60, 1) +@timeout(60*60*2, 1) async def do_handle_task(task): task_id = task["id"] task_from_page = task["from_page"]