diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index 913956ff7..8892bb2f6 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -152,7 +152,6 @@ class EntityResolution(Extractor): ) async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock): - gen_conf = {"temperature": 0.5} pair_txt = [ f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n'] for index, candidate in enumerate(candidate_resolution_i[1]): @@ -171,7 +170,7 @@ class EntityResolution(Extractor): async with chat_limiter: try: with trio.move_on_after(120) as cancel_scope: - response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf) + response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}) if cancel_scope.cancelled_caught: logging.warning("_resolve_candidate._chat timeout, skipping...") return diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index a400d3035..6a292b98c 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -90,11 +90,10 @@ class CommunityReportsExtractor(Extractor): "relation_df": rela_df.to_csv(index_label="id") } text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) - gen_conf = {"temperature": 0.3} async with chat_limiter: try: with trio.move_on_after(80) as cancel_scope: - response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf) + response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}) if cancel_scope.cancelled_caught: logging.warning("extract_community_report._chat timeout, skipping...") return diff --git a/graphrag/general/graph_extractor.py b/graphrag/general/graph_extractor.py index 0a963f25b..346a5b95b 100644 --- a/graphrag/general/graph_extractor.py +++ b/graphrag/general/graph_extractor.py @@ -105,10 +105,9 @@ class GraphExtractor(Extractor): **self._prompt_variables, self._input_text_key: content, } - gen_conf = {"temperature": 0.3} hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables) async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)) + response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], {})) token_count += num_tokens_from_string(hint_prompt + response) results = response or "" @@ -118,7 +117,7 @@ class GraphExtractor(Extractor): for i in range(self._max_gleanings): history.append({"role": "user", "content": CONTINUE_PROMPT}) async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf)) + response = await trio.to_thread.run_sync(lambda: self._chat("", history, {})) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response) results += response or "" diff --git a/graphrag/general/mind_map_extractor.py b/graphrag/general/mind_map_extractor.py index b4ee6343e..d713eb59f 100644 --- a/graphrag/general/mind_map_extractor.py +++ b/graphrag/general/mind_map_extractor.py @@ -171,9 +171,8 @@ class MindMapExtractor(Extractor): self._input_text_key: text, } text = perform_variable_replacements(self._mind_map_prompt, variables=variables) - gen_conf = {"temperature": 0.5} async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)) + response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {})) response = re.sub(r"```[^\n]*", "", response) logging.debug(response) logging.debug(self._todict(markdown_to_json.dictify(response))) diff --git a/graphrag/search.py b/graphrag/search.py index b5c5fe9b2..38ce19712 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -45,7 +45,7 @@ class KGSearch(Dealer): ty2ents = trio.run(lambda: get_entity_type2sampels(idxnms, kb_ids)) hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) - result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {"temperature": .5}) + result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) try: keywords_data = json_repair.loads(result) type_keywords = keywords_data.get("answer_type_keywords", []) diff --git a/rag/raptor.py b/rag/raptor.py index 80bf6ca88..961c5ab13 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -107,7 +107,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: ), } ], - {"temperature": 0.3, "max_tokens": self._max_token}, + {"max_tokens": self._max_token}, ) cnt = re.sub( "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 3e76b56bd..9edd6afe6 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -103,6 +103,7 @@ MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDER MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10')) task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) +embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO) kg_limiter = trio.CapacityLimiter(2) WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120')) @@ -442,7 +443,8 @@ async def embedding(docs, mdl, parser_config=None, callback=None): cnts_ = np.array([]) for i in range(0, len(cnts), EMBEDDING_BATCH_SIZE): - vts, c = await trio.to_thread.run_sync(lambda: mdl.encode([truncate(c, mdl.max_length-10) for c in cnts[i: i + EMBEDDING_BATCH_SIZE]])) + async with embed_limiter: + vts, c = await trio.to_thread.run_sync(lambda: mdl.encode([truncate(c, mdl.max_length-10) for c in cnts[i: i + EMBEDDING_BATCH_SIZE]])) if len(cnts_) == 0: cnts_ = vts else: