diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py
index 60792d381..97b135775 100644
--- a/graphrag/entity_resolution.py
+++ b/graphrag/entity_resolution.py
@@ -94,25 +94,52 @@ class EntityResolution(Extractor):
             candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)]
         num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
         callback(msg=f"Identified {num_candidates} candidate pairs")
+        remain_candidates_to_resolve = num_candidates
 
         resolution_result = set()
+        resolution_result_lock = trio.Lock()
         resolution_batch_size = 100
+        max_concurrent_tasks = 5
+        semaphore = trio.Semaphore(max_concurrent_tasks)
+
+        async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
+            nonlocal remain_candidates_to_resolve, callback
+            async with semaphore:
+                try:
+                    with trio.move_on_after(180) as cancel_scope:
+                        await self._resolve_candidate(candidate_batch, result_set, result_lock)
+                        remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
+                        callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} remaining to resolve.")
+                    if cancel_scope.cancelled_caught:
+                        logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
+                        remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
+                        callback(msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. {remain_candidates_to_resolve} remaining to resolve.")
+                except Exception as e:
+                    logging.error(f"Error resolving candidate batch: {e}")
+
         async with trio.open_nursery() as nursery:
             for candidate_resolution_i in candidate_resolution.items():
                 if not candidate_resolution_i[1]:
                     continue
                 for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
                     candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
-                    nursery.start_soon(self._resolve_candidate, candidate_batch, resolution_result)
+                    nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)
+
         callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
 
         change = GraphChange()
         connect_graph = nx.Graph()
         connect_graph.add_edges_from(resolution_result)
+
+        async def limited_merge_nodes(graph, nodes, change):
+            async with semaphore:
+                await self._merge_graph_nodes(graph, nodes, change)
+
         async with trio.open_nursery() as nursery:
             for sub_connect_graph in nx.connected_components(connect_graph):
                 merging_nodes = list(sub_connect_graph)
-                nursery.start_soon(self._merge_graph_nodes, graph, merging_nodes, change)
+                nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)
 
         # Update pagerank
         pr = nx.pagerank(graph)
@@ -124,7 +151,7 @@ class EntityResolution(Extractor):
             change=change,
         )
 
-    async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str]):
+    async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock):
         gen_conf = {"temperature": 0.5}
         pair_txt = [
             f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
@@ -142,7 +169,16 @@ class EntityResolution(Extractor):
         text = perform_variable_replacements(self._resolution_prompt, variables=variables)
         logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
         async with chat_limiter:
-            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+            try:
+                with trio.move_on_after(120) as cancel_scope:
+                    response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
+                if cancel_scope.cancelled_caught:
+                    logging.warning("_resolve_candidate._chat timeout, skipping...")
+                    return
+            except Exception as e:
+                logging.error(f"_resolve_candidate._chat failed: {e}")
+                return
+
         logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
         result = self._process_results(len(candidate_resolution_i[1]), response,
                                        self.prompt_variables.get(self._record_delimiter_key,
@@ -151,8 +187,9 @@ class EntityResolution(Extractor):
                                                                  DEFAULT_ENTITY_INDEX_DELIMITER),
                                        self.prompt_variables.get(self._resolution_result_delimiter_key,
                                                                  DEFAULT_RESOLUTION_RESULT_DELIMITER))
-        for result_i in result:
-            resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
+        async with resolution_result_lock:
+            for result_i in result:
+                resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
 
     def _process_results(
             self,
@@ -185,6 +222,7 @@ class EntityResolution(Extractor):
         if is_english(a) and is_english(b):
             if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
                 return True
+            return False
 
         if len(set(a) & set(b)) > 1:
             return True
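Note: the hunk above combines three trio primitives — a `Semaphore` to cap in-flight tasks, `move_on_after` to bound each task's runtime, and a `Lock` to guard the shared result set. The sketch below is a minimal, self-contained illustration of that pattern, not code from this patch; the batch IDs and `trio.sleep` stand in for the real LLM-backed `_resolve_candidate` call.

```python
# Minimal sketch, assuming only trio: bounded concurrency + per-task deadline
# + a lock around a shared set. Constants mirror the values in the patch.
import trio

MAX_CONCURRENT_TASKS = 5   # mirrors max_concurrent_tasks
TIMEOUT_SECONDS = 180      # mirrors the move_on_after(180) deadline

async def main():
    semaphore = trio.Semaphore(MAX_CONCURRENT_TASKS)
    results = set()
    results_lock = trio.Lock()

    async def resolve_batch(batch_id: int):
        async with semaphore:                          # at most 5 batches run at once
            with trio.move_on_after(TIMEOUT_SECONDS) as cancel_scope:
                await trio.sleep(0.01)                 # stand-in for the chat call
                async with results_lock:               # serialize writes to the set
                    results.add(batch_id)
            if cancel_scope.cancelled_caught:          # deadline hit: log and move on
                print(f"batch {batch_id} timed out, skipped")

    async with trio.open_nursery() as nursery:
        for batch_id in range(20):
            nursery.start_soon(resolve_batch, batch_id)
    print(f"resolved {len(results)} of 20 batches")

if __name__ == "__main__":
    trio.run(main)
```

Since trio schedules tasks cooperatively on one thread and the patched `for result_i in result: resolution_result.add(...)` loop contains no await, the lock is largely defensive; it does, however, make the multi-item update explicitly atomic if an await point is ever introduced inside it.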
diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py
index 14966af02..4d8b33bfd 100644
--- a/graphrag/general/community_reports_extractor.py
+++ b/graphrag/general/community_reports_extractor.py
@@ -89,7 +89,15 @@ class CommunityReportsExtractor(Extractor):
             text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
             gen_conf = {"temperature": 0.3}
             async with chat_limiter:
-                response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+                try:
+                    with trio.move_on_after(120) as cancel_scope:
+                        response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
+                    if cancel_scope.cancelled_caught:
+                        logging.warning("extract_community_report._chat timeout, skipping...")
+                        return
+                except Exception as e:
+                    logging.error(f"extract_community_report._chat failed: {e}")
+                    return
             token_count += num_tokens_from_string(text + response)
             response = re.sub(r"^[^\{]*", "", response)
             response = re.sub(r"[^\}]*$", "", response)
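Note: both `_chat` call sites now wrap `trio.to_thread.run_sync` in `move_on_after`. One caveat worth flagging: by default `trio.to_thread.run_sync` cannot be cancelled once its worker thread is running, so the deadline is only delivered after the thread returns; passing `abandon_on_cancel=True` (named `cancellable` in trio releases before 0.23) lets the timeout fire immediately and abandons the thread. The sketch below illustrates the difference under that assumption; `blocking_chat` is a stand-in for `self._chat`, not part of the patch.

```python
# Hedged sketch of a deadline around a blocking call offloaded to a thread.
import logging
import time

import trio

def blocking_chat() -> str:
    time.sleep(5)  # simulates a chat request that overruns the deadline
    return "response"

async def chat_with_deadline():
    with trio.move_on_after(2) as cancel_scope:
        # abandon_on_cancel=True lets move_on_after interrupt the wait;
        # with the default (False), this line blocks the full 5 seconds.
        return await trio.to_thread.run_sync(blocking_chat, abandon_on_cancel=True)
    if cancel_scope.cancelled_caught:
        logging.warning("chat timed out, skipping")
    return None

if __name__ == "__main__":
    print(trio.run(chat_with_deadline))  # prints None after ~2s
```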
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 75a087db9..10fb62b41 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -21,6 +21,8 @@ import sys
 import threading
 import time
 
+from valkey import RedisError
+
 from api.utils.log_utils import initRootLogger, get_project_base_directory
 from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -187,18 +189,44 @@ async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
     global UNACKED_ITERATOR
     svr_queue_names = get_svr_queue_names()
+    redis_msg = None
+
     try:
         if not UNACKED_ITERATOR:
-            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
-        try:
-            redis_msg = next(UNACKED_ITERATOR)
-        except StopIteration:
+            UNACKED_ITERATOR = None
+            logging.debug("Rebuilding UNACKED_ITERATOR because it is None")
+            try:
+                UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
+                logging.debug("UNACKED_ITERATOR rebuilt successfully")
+            except RedisError as e:
+                UNACKED_ITERATOR = None
+                logging.warning(f"Failed to rebuild UNACKED_ITERATOR: {e}")
+
+        if UNACKED_ITERATOR:
+            try:
+                redis_msg = next(UNACKED_ITERATOR)
+            except StopIteration:
+                UNACKED_ITERATOR = None
+                logging.debug("UNACKED_ITERATOR exhausted, clearing")
+
+            except Exception as e:
+                UNACKED_ITERATOR = None
+                logging.warning(f"UNACKED_ITERATOR raised exception: {e}")
+
+        if not redis_msg:
             for svr_queue_name in svr_queue_names:
-                redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
-                if redis_msg:
-                    break
-    except Exception:
-        logging.exception("collect got exception")
+                try:
+                    redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
+                    if redis_msg:
+                        break
+                except RedisError as e:
+                    logging.warning(f"queue_consumer failed for {svr_queue_name}: {e}")
+                    continue
+
+    except Exception as e:
+        logging.exception(f"collect task encountered unexpected exception: {e}")
+        UNACKED_ITERATOR = None
+        await trio.sleep(1)
         return None, None
 
     if not redis_msg:
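Note: the `collect()` rework narrows the broad `except Exception` to `RedisError` (imported from `valkey`, the client this module uses) at each call site, so a transient failure on one queue no longer aborts the whole scan; only truly unexpected errors reach the outer handler, which resets the iterator, backs off with `trio.sleep(1)`, and returns. The sketch below isolates that per-queue pattern; `conn` and `queue_consumer` mirror the `REDIS_CONN` wrapper assumed from the surrounding module, and this is an illustration, not the patched function itself.

```python
# Minimal sketch of per-queue error isolation: a backend-specific RedisError
# is logged and skipped so the remaining queues are still polled.
import logging

from valkey import RedisError

def consume_first_available(conn, queue_names, group_name, consumer_name):
    """Poll queues in order, skipping any queue whose read fails transiently."""
    for queue_name in queue_names:
        try:
            msg = conn.queue_consumer(queue_name, group_name, consumer_name)
            if msg:
                return msg
        except RedisError as e:
            # One bad queue should not stop the scan of the others.
            logging.warning(f"queue_consumer failed for {queue_name}: {e}")
            continue
    return None
```

Catching `RedisError` rather than `Exception` here keeps the retry scoped to connection-level faults (timeouts, dropped connections) while programming errors still propagate to the outer catch-all, where they are logged with a traceback instead of being silently swallowed.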