diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 30f5042ed..1c0ea19b6 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -302,6 +302,20 @@ "model_type": "chat", "is_tools": true }, + { + "llm_name": "qwen-plus-2025-07-28", + "tags": "LLM,CHAT,132k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "qwen-plus-2025-07-14", + "tags": "LLM,CHAT,132k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, { "llm_name": "qwq-plus-latest", "tags": "LLM,CHAT,132k", @@ -309,6 +323,20 @@ "model_type": "chat", "is_tools": true }, + { + "llm_name": "qwen-flash", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "qwen-flash-2025-07-28", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, { "llm_name": "qwen3-coder-480b-a35b-instruct", "tags": "LLM,CHAT,256k", diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 8a8655308..61d89e27c 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,17 +14,28 @@ # limitations under the License. # import logging +import os import re -from collections import defaultdict, Counter +from collections import Counter, defaultdict from copy import deepcopy from typing import Callable -import trio + import networkx as nx +import trio from api.utils.api_utils import timeout from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT -from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \ - handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange +from graphrag.utils import ( + GraphChange, + chat_limiter, + flat_uniq_list, + get_from_to, + get_llm_cache, + handle_single_entity_extraction, + handle_single_relationship_extraction, + set_llm_cache, + split_string_by_multi_markers, +) from rag.llm.chat_model import Base as CompletionLLM from rag.prompts import message_fit_in from rag.utils import truncate @@ -32,6 +43,7 @@ from rag.utils import truncate GRAPH_FIELD_SEP = "" DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"] ENTITY_EXTRACTION_MAX_GLEANINGS = 2 +MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK = int(os.environ.get("MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK", 10)) class Extractor: @@ -47,7 +59,7 @@ class Extractor: self._language = language self._entity_types = entity_types or DEFAULT_ENTITY_TYPES - @timeout(60*20) + @timeout(60 * 20) def _chat(self, system, history, gen_conf={}): hist = deepcopy(history) conf = deepcopy(gen_conf) @@ -55,6 +67,7 @@ class Extractor: if response: return response _, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92)) + response = "" for attempt in range(3): try: response = self._llm.chat(system_msg[0]["content"], hist, conf) @@ -74,38 +87,37 @@ class Extractor: maybe_edges = defaultdict(list) ent_types = [t.lower() for t in self._entity_types] for record in records: - record_attributes = split_string_by_multi_markers( - record, [tuple_delimiter] - ) + record_attributes = split_string_by_multi_markers(record, [tuple_delimiter]) - if_entities = handle_single_entity_extraction( - record_attributes, chunk_key - ) + if_entities = handle_single_entity_extraction(record_attributes, chunk_key) if if_entities is not None and if_entities.get("entity_type", "unknown").lower() in ent_types: maybe_nodes[if_entities["entity_name"]].append(if_entities) continue - if_relation = handle_single_relationship_extraction( - record_attributes, chunk_key - ) + if_relation = handle_single_relationship_extraction(record_attributes, chunk_key) if if_relation is not None: - maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( - if_relation - ) + maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(if_relation) return dict(maybe_nodes), dict(maybe_edges) - async def __call__( - self, doc_id: str, chunks: list[str], - callback: Callable | None = None - ): - + async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None): self.callback = callback start_ts = trio.current_time() - out_results = [] - async with trio.open_nursery() as nursery: - for i, ck in enumerate(chunks): - ck = truncate(ck, int(self._llm.max_length*0.8)) - nursery.start_soon(self._process_single_content, (doc_id, ck), i, len(chunks), out_results) + + async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK): + out_results = [] + limiter = trio.Semaphore(max_concurrency) + + async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int): + async with limiter: + await self._process_single_content(chunk_key_dp, idx, total, out_results) + + async with trio.open_nursery() as nursery: + for i, ck in enumerate(chunks): + nursery.start_soon(worker, (doc_id, ck), i, len(chunks)) + + return out_results + + out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK) maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) @@ -118,7 +130,7 @@ class Extractor: sum_token_count += token_count now = trio.current_time() if callback: - callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.") + callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.") start_ts = now logging.info("Entities merging...") all_entities_data = [] @@ -127,7 +139,7 @@ class Extractor: nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data) now = trio.current_time() if callback: - callback(msg = f"Entities merging done, {now-start_ts:.2f}s.") + callback(msg=f"Entities merging done, {now - start_ts:.2f}s.") start_ts = now logging.info("Relationships merging...") @@ -137,12 +149,10 @@ class Extractor: nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data) now = trio.current_time() if callback: - callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.") + callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.") if not len(all_entities_data) and not len(all_relationships_data): - logging.warning( - "Didn't extract any entities and relationships, maybe your LLM is not working" - ) + logging.warning("Didn't extract any entities and relationships, maybe your LLM is not working") if not len(all_entities_data): logging.warning("Didn't extract any entities") @@ -155,15 +165,11 @@ class Extractor: if not entities: return entity_type = sorted( - Counter( - [dp["entity_type"] for dp in entities] - ).items(), + Counter([dp["entity_type"] for dp in entities]).items(), key=lambda x: x[1], reverse=True, )[0][0] - description = GRAPH_FIELD_SEP.join( - sorted(set([dp["description"] for dp in entities])) - ) + description = GRAPH_FIELD_SEP.join(sorted(set([dp["description"] for dp in entities]))) already_source_ids = flat_uniq_list(entities, "source_id") description = await self._handle_entity_relation_summary(entity_name, description) node_data = dict( @@ -174,13 +180,7 @@ class Extractor: node_data["entity_name"] = entity_name all_relationships_data.append(node_data) - async def _merge_edges( - self, - src_id: str, - tgt_id: str, - edges_data: list[dict], - all_relationships_data=None - ): + async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None): if not edges_data: return weight = sum([edge["weight"] for edge in edges_data]) @@ -188,14 +188,7 @@ class Extractor: description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description) keywords = flat_uniq_list(edges_data, "keywords") source_id = flat_uniq_list(edges_data, "source_id") - edge_data = dict( - src_id=src_id, - tgt_id=tgt_id, - description=description, - keywords=keywords, - weight=weight, - source_id=source_id - ) + edge_data = dict(src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, weight=weight, source_id=source_id) all_relationships_data.append(edge_data) async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange): @@ -231,14 +224,10 @@ class Extractor: node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"]) graph.nodes[nodes[0]].update(node0_attrs) - async def _handle_entity_relation_summary( - self, - entity_or_relation_name: str, - description: str - ) -> str: + async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str) -> str: summary_max_tokens = 512 use_description = truncate(description, summary_max_tokens) - description_list=use_description.split(GRAPH_FIELD_SEP), + description_list = (use_description.split(GRAPH_FIELD_SEP),) if len(description_list) <= 12: return use_description prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT @@ -250,5 +239,5 @@ class Extractor: use_prompt = prompt_template.format(**context_base) logging.info(f"Trigger summary: {entity_or_relation_name}") async with chat_limiter: - summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}])) + summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}]) return summary diff --git a/graphrag/general/index.py b/graphrag/general/index.py index e5150c54a..9e80309f2 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,25 +23,24 @@ import trio from api import settings from api.utils import get_uuid from api.utils.api_utils import timeout -from graphrag.light.graph_extractor import GraphExtractor as LightKGExt -from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt -from graphrag.general.community_reports_extractor import CommunityReportsExtractor from graphrag.entity_resolution import EntityResolution +from graphrag.general.community_reports_extractor import CommunityReportsExtractor from graphrag.general.extractor import Extractor +from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt +from graphrag.light.graph_extractor import GraphExtractor as LightKGExt from graphrag.utils import ( - graph_merge, - get_graph, - set_graph, + GraphChange, chunk_id, does_graph_contains, + get_graph, + graph_merge, + set_graph, tidy_graph, - GraphChange, ) from rag.nlp import rag_tokenizer, search from rag.utils.redis_conn import RedisDistributedLock - async def run_graphrag( row: dict, language, @@ -51,20 +50,16 @@ async def run_graphrag( embedding_model, callback, ): - enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION") + enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") start = trio.current_time() tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] chunks = [] - for d in settings.retrievaler.chunk_list( - doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"] - ): + for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]): chunks.append(d["content_with_weight"]) - with trio.fail_after(max(120, len(chunks)*60*10) if enable_timeout_assertion else 10000000000): + with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): subgraph = await generate_subgraph( - LightKGExt - if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" - else GeneralKGExt, + LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt, tenant_id, kb_id, doc_id, @@ -177,9 +172,7 @@ async def generate_subgraph( subgraph.graph["source_id"] = [doc_id] chunk = { - "content_with_weight": json.dumps( - nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False - ), + "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False), "knowledge_graph_kwd": "subgraph", "kb_id": kb_id, "source_id": [doc_id], @@ -187,22 +180,14 @@ async def generate_subgraph( "removed_kwd": "N", } cid = chunk_id(chunk) - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.delete( - {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id - ) - ) - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.insert( - [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id - ) - ) + await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id) + await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id) now = trio.current_time() callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.") return subgraph -@timeout(60*3) +@timeout(60 * 3) async def merge_subgraph( tenant_id: str, kb_id: str, @@ -228,13 +213,11 @@ async def merge_subgraph( await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback) now = trio.current_time() - callback( - msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds." - ) + callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.") return new_graph -@timeout(60*30, 1) +@timeout(60 * 30, 1) async def resolve_entities( graph, subgraph_nodes: set[str], @@ -260,7 +243,7 @@ async def resolve_entities( callback(msg=f"Graph resolution done in {now - start:.2f}s.") -@timeout(60*30, 1) +@timeout(60 * 30, 1) async def extract_community( graph, tenant_id: str, @@ -280,9 +263,7 @@ async def extract_community( doc_ids = graph.graph["source_id"] now = trio.current_time() - callback( - msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s." - ) + callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.") start = now chunks = [] for stru, rep in zip(community_structure, community_reports): @@ -295,9 +276,7 @@ async def extract_community( "docnm_kwd": stru["title"], "title_tks": rag_tokenizer.tokenize(stru["title"]), "content_with_weight": json.dumps(obj, ensure_ascii=False), - "content_ltks": rag_tokenizer.tokenize( - obj["report"] + " " + obj["evidences"] - ), + "content_ltks": rag_tokenizer.tokenize(obj["report"] + " " + obj["evidences"]), "knowledge_graph_kwd": "community_report", "weight_flt": stru["weight"], "entities_kwd": stru["entities"], @@ -306,9 +285,7 @@ async def extract_community( "source_id": list(doc_ids), "available_int": 0, } - chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize( - chunk["content_ltks"] - ) + chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) chunks.append(chunk) await trio.to_thread.run_sync( @@ -320,13 +297,11 @@ async def extract_community( ) es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id)) + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) if doc_store_result: error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" raise Exception(error_message) now = trio.current_time() - callback( - msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s." - ) + callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.") return community_structure, community_reports diff --git a/graphrag/light/__init__.py b/graphrag/light/__init__.py index e69de29bb..177b91dd0 100644 --- a/graphrag/light/__init__.py +++ b/graphrag/light/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/graphrag/light/graph_extractor.py b/graphrag/light/graph_extractor.py index 9c869b16d..474d47597 100644 --- a/graphrag/light/graph_extractor.py +++ b/graphrag/light/graph_extractor.py @@ -4,17 +4,21 @@ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ + +import logging import re -from typing import Any from dataclasses import dataclass -from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS -from graphrag.light.graph_prompt import PROMPTS -from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers, chat_limiter -from rag.llm.chat_model import Base as CompletionLLM +from typing import Any + import networkx as nx -from rag.utils import num_tokens_from_string import trio +from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor +from graphrag.light.graph_prompt import PROMPTS +from graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split_string_by_multi_markers +from rag.llm.chat_model import Base as CompletionLLM +from rag.utils import num_tokens_from_string + @dataclass class GraphExtractionResult: @@ -25,7 +29,6 @@ class GraphExtractionResult: class GraphExtractor(Extractor): - _max_gleanings: int def __init__( @@ -38,15 +41,9 @@ class GraphExtractor(Extractor): ): super().__init__(llm_invoker, language, entity_types) """Init method definition.""" - self._max_gleanings = ( - max_gleanings - if max_gleanings is not None - else ENTITY_EXTRACTION_MAX_GLEANINGS - ) + self._max_gleanings = max_gleanings if max_gleanings is not None else ENTITY_EXTRACTION_MAX_GLEANINGS self._example_number = example_number - examples = "\n".join( - PROMPTS["entity_extraction_examples"][: int(self._example_number)] - ) + examples = "\n".join(PROMPTS["entity_extraction_examples"][: int(self._example_number)]) example_context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], @@ -68,45 +65,52 @@ class GraphExtractor(Extractor): language=self._language, ) - self._continue_prompt = PROMPTS["entiti_continue_extraction"] - self._if_loop_prompt = PROMPTS["entiti_if_loop_extraction"] + self._continue_prompt = PROMPTS["entity_continue_extraction"].format(**self._context_base) + self._if_loop_prompt = PROMPTS["entity_if_loop_extraction"] - self._left_token_count = llm_invoker.max_length - num_tokens_from_string( - self._entity_extract_prompt.format( - **self._context_base, input_text="{input_text}" - ).format(**self._context_base, input_text="") - ) + self._left_token_count = llm_invoker.max_length - num_tokens_from_string(self._entity_extract_prompt.format(**self._context_base, input_text="")) self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count) async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results): token_count = 0 chunk_key = chunk_key_dp[0] content = chunk_key_dp[1] - hint_prompt = self._entity_extract_prompt.format( - **self._context_base, input_text="{input_text}" - ).format(**self._context_base, input_text=content) + hint_prompt = self._entity_extract_prompt.format(**self._context_base, input_text=content) gen_conf = {} + final_result = "" + glean_result = "" + if_loop_result = "" + history = [] + logging.info(f"Start processing for {chunk_key}: {content[:25]}...") + if self.callback: + self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...") async with chat_limiter: - final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)) + final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf) token_count += num_tokens_from_string(hint_prompt + final_result) - history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt) + history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt) for now_glean_index in range(self._max_gleanings): async with chat_limiter: - glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf)) - history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}]) + # glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf)) + glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf) + history.extend([{"role": "assistant", "content": glean_result}]) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt) final_result += glean_result if now_glean_index == self._max_gleanings - 1: break + history.extend([{"role": "user", "content": self._if_loop_prompt}]) async with chat_limiter: - if_loop_result = await trio.to_thread.run_sync(lambda: self._chat(self._if_loop_prompt, history, gen_conf)) + if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if if_loop_result != "yes": break + history.extend([{"role": "assistant", "content": if_loop_result}, {"role": "user", "content": self._continue_prompt}]) + logging.info(f"Completed processing for {chunk_key}: {content[:25]}... after {now_glean_index} gleanings, {token_count} tokens.") + if self.callback: + self.callback(msg=f"Completed processing for {chunk_key}: {content[:25]}... after {now_glean_index} gleanings, {token_count} tokens.") records = split_string_by_multi_markers( final_result, [self._context_base["record_delimiter"], self._context_base["completion_delimiter"]], @@ -121,4 +125,7 @@ class GraphExtractor(Extractor): maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"]) out_results.append((maybe_nodes, maybe_edges, token_count)) if self.callback: - self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.") + self.callback( + 0.5 + 0.1 * len(out_results) / num_chunks, + msg=f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.", + ) diff --git a/graphrag/light/graph_prompt.py b/graphrag/light/graph_prompt.py index a3bf8c44c..865937afb 100644 --- a/graphrag/light/graph_prompt.py +++ b/graphrag/light/graph_prompt.py @@ -4,26 +4,28 @@ Reference: - [LightRAG](https://github.com/HKUDS/LightRAG/blob/main/lightrag/prompt.py) """ +from typing import Any -PROMPTS = {} +PROMPTS: dict[str, Any] = {} PROMPTS["DEFAULT_LANGUAGE"] = "English" PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>" PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##" PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>" -PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"] -PROMPTS["entity_extraction"] = """-Goal- +PROMPTS["DEFAULT_USER_PROMPT"] = "n/a" + +PROMPTS["entity_extraction"] = """---Goal--- Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. Use {language} as output language. --Steps- +---Steps--- 1. Identify all entities. For each identified entity, extract the following information: - entity_name: Name of the entity, use same language as input text. If English, capitalized the name. - entity_type: One of the following types: [{entity_types}] -- entity_description: Comprehensive description of the entity's attributes and activities +- entity_description: Provide a comprehensive description of the entity's attributes and activities *based solely on the information present in the input text*. **Do not infer or hallucinate information not explicitly stated.** If the text provides insufficient information to create a comprehensive description, state "Description not available in text." Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. @@ -43,31 +45,34 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter}` + - For a Knowledge Graph Relationship: `[KG] - ` + - For a Document Chunk: `[DC] ` + +---USER CONTEXT--- +- Additional user prompt: {user_prompt} + + +Response:""" + +PROMPTS["keywords_extraction"] = """---Role--- +You are an expert keyword extractor, specializing in analyzing user queries for a Retrieval-Augmented Generation (RAG) system. Your purpose is to identify both high-level and low-level keywords in the user's query that will be used for effective document retrieval. + +---Goal--- +Given a user query, your task is to extract two distinct types of keywords: +1. **high_level_keywords**: for overarching concepts or themes, capturing user's core intent, the subject area, or the type of question being asked. +2. **low_level_keywords**: for specific entities or details, identifying the specific entities, proper nouns, technical jargon, product names, or concrete items. + +---Instructions & Constraints--- +1. **Output Format**: Your output MUST be a valid JSON object and nothing else. Do not include any explanatory text, markdown code fences (like ```json), or any other text before or after the JSON. It will be parsed directly by a JSON parser. +2. **Source of Truth**: All keywords must be explicitly derived from the user query, with both high-level and low-level keyword categories required to contain content. +3. **Concise & Meaningful**: Keywords should be concise words or meaningful phrases. Prioritize multi-word phrases when they represent a single concept. For example, from "latest financial report of Apple Inc.", you should extract "latest financial report" and "Apple Inc." rather than "latest", "financial", "report", and "Apple". +4. **Handle Edge Cases**: For queries that are too simple, vague, or nonsensical (e.g., "hello", "ok", "asdfghjkl"), you must return a JSON object with empty lists for both keyword types. + +---Examples--- +{examples} + +---Real Data--- +User Query: {query} + +---Output--- +""" + +PROMPTS["keywords_extraction_examples"] = [ + """Example 1: + +Query: "How does international trade influence global economic stability?" + +Output: +{ + "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"], + "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"] +} + +""", + """Example 2: + +Query: "What are the environmental consequences of deforestation on biodiversity?" + +Output: +{ + "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"], + "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"] +} + +""", + """Example 3: + +Query: "What is the role of education in reducing poverty?" + +Output: +{ + "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"], + "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"] +} + +""", +] PROMPTS["naive_rag_response"] = """---Role--- -You are a helpful assistant responding to questions about documents provided. - +You are a helpful assistant responding to user query about Document Chunks provided provided in JSON format below. ---Goal--- -Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. -If you don't know the answer, just say so. Do not make anything up. -Do not include information where the supporting evidence for it is not provided. +Generate a concise response based on Document Chunks and follow Response Rules, considering both the conversation history and the current query. Summarize all information in the provided Document Chunks, and incorporating general knowledge relevant to the Document Chunks. Do not include information not provided by Document Chunks. -When handling content with timestamps: -1. Each piece of content has a "created_at" timestamp indicating when we acquired this knowledge -2. When encountering conflicting information, consider both the content and the timestamp -3. Don't automatically prefer the most recent content - use judgment based on the context -4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps - ----Target response length and format--- - -{response_type} - ----Documents--- +---Conversation History--- +{history} +---Document Chunks(DC)--- {content_data} -Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. -""" - -PROMPTS[ - "similarity_check" -] = """Please analyze the similarity between these two questions: - -Question 1: {original_prompt} -Question 2: {cached_prompt} - -Please evaluate the following two points and provide a similarity score between 0 and 1 directly: -1. Whether these two questions are semantically similar -2. Whether the answer to Question 2 can be used to answer Question 1 -Similarity score criteria: -0: Completely unrelated or answer cannot be reused, including but not limited to: - - The questions have different topics - - The locations mentioned in the questions are different - - The times mentioned in the questions are different - - The specific individuals mentioned in the questions are different - - The specific events mentioned in the questions are different - - The background information in the questions is different - - The key conditions in the questions are different -1: Identical and answer can be directly reused -0.5: Partially related and answer needs modification to be used -Return only a number between 0-1, without any additional content. -""" - -PROMPTS["mix_rag_response"] = """---Role--- - -You are a professional assistant responsible for answering questions based on knowledge graph and textual information. Please respond in the same language as the user's question. - ----Goal--- - -Generate a concise response that summarizes relevant points from the provided information. If you don't know the answer, just say so. Do not make anything up or include information where the supporting evidence is not provided. - -When handling information with timestamps: -1. Each piece of information (both relationships and content) has a "created_at" timestamp indicating when we acquired this knowledge -2. When encountering conflicting information, consider both the content/relationship and the timestamp -3. Don't automatically prefer the most recent information - use judgment based on the context -4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps - ----Data Sources--- - -1. Knowledge Graph Data: -{kg_context} - -2. Vector Data: -{vector_context} - ----Response Requirements--- +---RESPONSE GUIDELINES--- +**1. Content & Adherence:** +- Strictly adhere to the provided context from the Knowledge Base. Do not invent, assume, or include any information not present in the source data. +- If the answer cannot be found in the provided context, state that you do not have enough information to answer. +- Ensure the response maintains continuity with the conversation history. +**2. Formatting & Language:** +- Format the response using markdown with appropriate section headings. +- The response language must match the user's question language. - Target format and length: {response_type} -- Use markdown formatting with appropriate section headings -- Aim to keep content around 3 paragraphs for conciseness -- Each paragraph should be under a relevant section heading -- Each section should focus on one main point or aspect of the answer -- Use clear and descriptive section titles that reflect the content -- List up to 5 most important reference sources at the end under "References", clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (VD) - Format: [KG/VD] Source content -Add sections and commentary to the response as appropriate for the length and format. If the provided information is insufficient to answer the question, clearly state that you don't know or cannot provide an answer in the same language as the user's question.""" +**3. Citations / References:** +- At the end of the response, under a "References" section, cite a maximum of 5 most relevant sources used. +- Use the following formats for citations: `[DC] ` + +---USER CONTEXT--- +- Additional user prompt: {user_prompt} + + +Response:""" diff --git a/graphrag/utils.py b/graphrag/utils.py index fbe391f8f..6b80d7fe8 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -6,27 +6,27 @@ Reference: - [LightRag](https://github.com/HKUDS/LightRAG) """ +import dataclasses import html import json import logging +import os import re import time from collections import defaultdict from hashlib import md5 -from typing import Any, Callable -import os -import trio -from typing import Set, Tuple +from typing import Any, Callable, Set, Tuple + import networkx as nx import numpy as np +import trio import xxhash from networkx.readwrite import json_graph -import dataclasses -from api.utils.api_utils import timeout from api import settings from api.utils import get_uuid -from rag.nlp import search, rag_tokenizer +from api.utils.api_utils import timeout +from rag.nlp import rag_tokenizer, search from rag.utils.doc_store_conn import OrderByExpr from rag.utils.redis_conn import REDIS_CONN @@ -34,7 +34,8 @@ GRAPH_FIELD_SEP = "" ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] -chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10))) +chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) + @dataclasses.dataclass class GraphChange: @@ -43,9 +44,8 @@ class GraphChange: removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set) added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set) -def perform_variable_replacements( - input: str, history: list[dict] | None = None, variables: dict | None = None -) -> str: + +def perform_variable_replacements(input: str, history: list[dict] | None = None, variables: dict | None = None) -> str: """Perform variable replacements on the input string and in a chat log.""" if history is None: history = [] @@ -78,9 +78,7 @@ def clean_str(input: Any) -> str: return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result) -def dict_has_keys_with_types( - data: dict, expected_fields: list[tuple[str, type]] -) -> bool: +def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]]) -> bool: """Return True if the given dictionary has the given keys with the given types.""" for field, field_type in expected_fields: if field not in data: @@ -102,7 +100,7 @@ def get_llm_cache(llmnm, txt, history, genconf): k = hasher.hexdigest() bin = REDIS_CONN.get(k) if not bin: - return + return None return bin @@ -114,7 +112,7 @@ def set_llm_cache(llmnm, txt, v, history, genconf): hasher.update(str(genconf).encode("utf-8")) k = hasher.hexdigest() - REDIS_CONN.set(k, v.encode("utf-8"), 24*3600) + REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600) def get_embed_cache(llmnm, txt): @@ -136,7 +134,7 @@ def set_embed_cache(llmnm, txt, arr): k = hasher.hexdigest() arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr) - REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600) + REDIS_CONN.set(k, arr.encode("utf-8"), 24 * 3600) def get_tags_from_cache(kb_ids): @@ -162,6 +160,7 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True): """ Ensure all nodes and edges in the graph have some essential attribute. """ + def is_valid_item(node_attrs: dict) -> bool: valid_node = True for attr in ["description", "source_id"]: @@ -169,6 +168,7 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True): valid_node = False break return valid_node + if check_attribute: purged_nodes = [] for node, node_attrs in graph.nodes(data=True): @@ -267,9 +267,7 @@ def handle_single_relationship_extraction(record_attributes: list[str], chunk_ke edge_keywords = clean_str(record_attributes[4]) edge_source_id = chunk_key - weight = ( - float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0 - ) + weight = float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0 pair = sorted([source.upper(), target.upper()]) return dict( src_id=pair[0], @@ -284,9 +282,7 @@ def handle_single_relationship_extraction(record_attributes: list[str], chunk_ke def pack_user_ass_to_openai_messages(*args: str): roles = ["user", "assistant"] - return [ - {"role": roles[i % 2], "content": content} for i, content in enumerate(args) - ] + return [{"role": roles[i % 2], "content": content} for i, content in enumerate(args)] def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: @@ -307,7 +303,7 @@ def chunk_id(chunk): async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks): global chat_limiter - enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION") + enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") chunk = { "id": get_uuid(), "important_kwd": [ent_name], @@ -319,7 +315,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks): "content_ltks": rag_tokenizer.tokenize(meta["description"]), "source_id": meta["source_id"], "kb_id": kb_id, - "available_int": 0 + "available_int": 0, } chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) ebd = get_embed_cache(embd_mdl.llm_name, ent_name) @@ -343,13 +339,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1): to_ent_name = [to_ent_name] ents.extend(to_ent_name) ents = list(set(ents)) - conds = { - "fields": ["content_with_weight"], - "size": size, - "from_entity_kwd": ents, - "to_entity_kwd": ents, - "knowledge_graph_kwd": ["relation"] - } + conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]} res = [] es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id) for id in es_res.ids: @@ -363,7 +353,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1): async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks): - enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION") + enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") chunk = { "id": get_uuid(), "from_entity_kwd": from_ent_name, @@ -375,7 +365,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, "source_id": meta["source_id"], "weight_int": int(meta["weight"]), "kb_id": kb_id, - "available_int": 0 + "available_int": 0, } chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) txt = f"{from_ent_name}->{to_ent_name}" @@ -383,7 +373,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, if ebd is None: async with chat_limiter: with trio.fail_after(3 if enable_timeout_assertion else 300000000): - ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"])) + ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"])) ebd = ebd[0] set_embed_cache(embd_mdl.llm_name, txt, ebd) assert ebd is not None @@ -407,12 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: - conds = { - "fields": ["source_id"], - "removed_kwd": "N", - "size": 1, - "knowledge_graph_kwd": ["graph"] - } + conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]} res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])) doc_ids = [] if res.total == 0: @@ -423,12 +408,8 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: async def get_graph(tenant_id, kb_id, exclude_rebuild=None): - conds = { - "fields": ["content_with_weight", "removed_kwd", "source_id"], - "size": 1, - "knowledge_graph_kwd": ["graph"] - } - res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])) + conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]} + res = await trio.to_thread.run_sync(settings.retrievaler.search, conds, search.index_name(tenant_id), [kb_id]) if not res.total == 0: for id in res.ids: try: @@ -449,56 +430,63 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang global chat_limiter start = trio.current_time() - await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)) + await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id) if change.removed_nodes: - await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)) - + await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id) if change.removed_edges: + async def del_edges(from_node, to_node): async with chat_limiter: - await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)) + await trio.to_thread.run_sync( + settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id + ) + async with trio.open_nursery() as nursery: for from_node, to_node in change.removed_edges: - nursery.start_soon(del_edges, from_node, to_node) + nursery.start_soon(del_edges, from_node, to_node) now = trio.current_time() if callback: callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.") start = now - chunks = [{ - "id": get_uuid(), - "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False), - "knowledge_graph_kwd": "graph", - "kb_id": kb_id, - "source_id": graph.graph.get("source_id", []), - "available_int": 0, - "removed_kwd": "N" - }] - + chunks = [ + { + "id": get_uuid(), + "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False), + "knowledge_graph_kwd": "graph", + "kb_id": kb_id, + "source_id": graph.graph.get("source_id", []), + "available_int": 0, + "removed_kwd": "N", + } + ] + # generate updated subgraphs for source in graph.graph["source_id"]: subgraph = graph.subgraph([n for n in graph.nodes if source in graph.nodes[n]["source_id"]]).copy() subgraph.graph["source_id"] = [source] for n in subgraph.nodes: subgraph.nodes[n]["source_id"] = [source] - chunks.append({ - "id": get_uuid(), - "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False), - "knowledge_graph_kwd": "subgraph", - "kb_id": kb_id, - "source_id": [source], - "available_int": 0, - "removed_kwd": "N" - }) + chunks.append( + { + "id": get_uuid(), + "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False), + "knowledge_graph_kwd": "subgraph", + "kb_id": kb_id, + "source_id": [source], + "available_int": 0, + "removed_kwd": "N", + } + ) async with trio.open_nursery() as nursery: for ii, node in enumerate(change.added_updated_nodes): node_attrs = graph.nodes[node] nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks) - if ii%100 == 9 and callback: + if ii % 100 == 9 and callback: callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}") async with trio.open_nursery() as nursery: @@ -508,7 +496,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging. continue nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks) - if ii%100 == 9 and callback: + if ii % 100 == 9 and callback: callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}") now = trio.current_time() @@ -516,11 +504,11 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.") start = now - enable_timeout_assertion=os.environ.get("ENABLE_TIMEOUT_ASSERTION") + enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): with trio.fail_after(3 if enable_timeout_assertion else 30000000): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id)) + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) if b % 100 == es_bulk_size and callback: callback(msg=f"Insert chunks: {b}/{len(chunks)}") if doc_store_result: @@ -544,10 +532,10 @@ def is_continuous_subsequence(subseq, seq): break return indexes - index_list = find_all_indexes(seq,subseq[0]) + index_list = find_all_indexes(seq, subseq[0]) for idx in index_list: - if idx!=len(seq)-1: - if seq[idx+1]==subseq[-1]: + if idx != len(seq) - 1: + if seq[idx + 1] == subseq[-1]: return True return False @@ -574,10 +562,7 @@ def merge_tuples(list1, list2): async def get_entity_type2sampels(idxnms, kb_ids: list): - es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, - "size": 10000, - "fields": ["content_with_weight"]}, - idxnms, kb_ids)) + es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)) res = defaultdict(list) for id in es_res.ids: @@ -609,13 +594,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None): graph = nx.Graph() flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"] bs = 256 - for i in range(0, 1024*bs, bs): - es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [], - {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, - [], - OrderByExpr(), - i, bs, search.index_name(tenant_id), [kb_id] - )) + for i in range(0, 1024 * bs, bs): + es_res = await trio.to_thread.run_sync( + lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) + ) # tot = settings.docStoreConn.getTotal(es_res) es_res = settings.docStoreConn.getFields(es_res, flds) @@ -629,13 +611,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None): continue elif exclude_rebuild in d["source_id"]: continue - + next_graph = json_graph.node_link_graph(json.loads(d["content_with_weight"]), edges="edges") merged_graph = nx.compose(graph, next_graph) - merged_source = { - n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"] - for n in graph.nodes & next_graph.nodes - } + merged_source = {n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"] for n in graph.nodes & next_graph.nodes} nx.set_node_attributes(merged_graph, merged_source, "source_id") if "source_id" in graph.graph: merged_graph.graph["source_id"] = graph.graph["source_id"] + next_graph.graph["source_id"] diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 5575aa390..17a6ccea9 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -239,7 +239,7 @@ class Base(ABC): def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) ans = "" @@ -293,7 +293,7 @@ class Base(ABC): assert False, "Shouldn't be here." def chat(self, system, history, gen_conf={}, **kwargs): - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) @@ -324,7 +324,7 @@ class Base(ABC): def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) tools = self.tools - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) total_tokens = 0 @@ -427,7 +427,7 @@ class Base(ABC): assert False, "Shouldn't be here." def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) ans = "" @@ -576,7 +576,7 @@ class BaiChuanChat(Base): return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf={}, **kwargs): - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -641,7 +641,7 @@ class ZhipuChat(Base): return super().chat_with_tools(system, history, gen_conf) def chat_streamly(self, system, history, gen_conf={}, **kwargs): - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -705,7 +705,7 @@ class LocalLLM(Base): def _prepare_prompt(self, system, history, gen_conf): from rag.svr.jina_server import Prompt - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) return Prompt(message=history, gen_conf=gen_conf) @@ -792,7 +792,7 @@ class MiniMaxChat(Base): return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: @@ -865,7 +865,7 @@ class MistralChat(Base): return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf={}, **kwargs): - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: @@ -1089,7 +1089,7 @@ class HunyuanChat(Base): _gen_conf = {} _history = [{k.capitalize(): v for k, v in item.items()} for item in history] - if system: + if system and history and history[0].get("role") != "system": _history.insert(0, {"Role": "system", "Content": system}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1565,7 +1565,7 @@ class LiteLLMBase(ABC): def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) ans = "" @@ -1630,7 +1630,7 @@ class LiteLLMBase(ABC): assert False, "Shouldn't be here." def chat(self, system, history, gen_conf={}, **kwargs): - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) @@ -1662,7 +1662,7 @@ class LiteLLMBase(ABC): def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) tools = self.tools - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) total_tokens = 0 @@ -1787,7 +1787,7 @@ class LiteLLMBase(ABC): assert False, "Shouldn't be here." def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): - if system: + if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) ans = "" diff --git a/sandbox/executor_manager/services/execution.py b/sandbox/executor_manager/services/execution.py index 1371ee95f..eae366585 100644 --- a/sandbox/executor_manager/services/execution.py +++ b/sandbox/executor_manager/services/execution.py @@ -162,7 +162,7 @@ if (fs.existsSync(mainPath)) { elif language == SupportLanguage.NODEJS: run_args.extend([]) else: - assert True, "Will never reach here" + assert False, "Will never reach here" run_args.extend([runner_name, args_json]) returncode, stdout, stderr = await async_run_command(