diff --git a/graphrag/general/graph_extractor.py b/graphrag/general/graph_extractor.py
index d7d874d98..e3c911260 100644
--- a/graphrag/general/graph_extractor.py
+++ b/graphrag/general/graph_extractor.py
@@ -13,7 +13,7 @@ import trio
 from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
 from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
-from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
+from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter, split_string_by_multi_markers
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
 from rag.utils import num_tokens_from_string
@@ -121,8 +121,7 @@ class GraphExtractor(Extractor):
 
         # Repeat to ensure we maximize entity count
         for i in range(self._max_gleanings):
-            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
-            history.append({"role": "user", "content": text})
+            history.append({"role": "user", "content": CONTINUE_PROMPT})
             async with chat_limiter:
                 response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf))
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
@@ -138,11 +137,19 @@ class GraphExtractor(Extractor):
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
             if continuation != "YES":
                 break
-        record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
-        tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
-        records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
-        records = [r for r in records if r.strip()]
-        maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
+
+        records = split_string_by_multi_markers(
+            results,
+            [self._prompt_variables[self._record_delimiter_key], self._prompt_variables[self._completion_delimiter_key]],
+        )
+        rcds = []
+        for record in records:
+            record = re.search(r"\((.*)\)", record)
+            if record is None:
+                continue
+            rcds.append(record.group(1))
+        records = rcds
+        maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._prompt_variables[self._tuple_delimiter_key])
         out_results.append((maybe_nodes, maybe_edges, token_count))
         if self.callback:
             self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
diff --git a/graphrag/general/index.py b/graphrag/general/index.py
index 8b63b0d02..dabb8a098 100644
--- a/graphrag/general/index.py
+++ b/graphrag/general/index.py
@@ -343,7 +343,7 @@ async def extract_community(
             "entities_kwd": stru["entities"],
             "important_kwd": stru["entities"],
             "kb_id": kb_id,
-            "source_id": doc_ids,
+            "source_id": list(doc_ids),
             "available_int": 0,
         }
         chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
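
Note for reviewers: the record-parsing change swaps a single `str.split` on the record delimiter for `split_string_by_multi_markers`, splitting on the record delimiter and the completion delimiter at once. As a rough sketch of what that helper is assumed to do (the real implementation lives in `graphrag.utils` and may differ in detail):

```python
import re


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split `content` on any of the given marker strings, dropping empty pieces.

    NOTE: assumed behavior of graphrag.utils.split_string_by_multi_markers,
    reconstructed for illustration; the actual helper may differ in detail.
    """
    if not markers:
        return [content]
    # Escape each marker so delimiters like "##" or "<|COMPLETE|>" are
    # treated literally, then split on any of them in one pass.
    pattern = "|".join(re.escape(marker) for marker in markers)
    return [piece.strip() for piece in re.split(pattern, content) if piece.strip()]
```

With the completion delimiter included in the marker list, a trailing completion token (GraphRAG-style prompts conventionally use `<|COMPLETE|>`) no longer leaks into the last record, and the follow-up `re.search(r"\((.*)\)", record)` pass keeps only the parenthesized tuple body of each record.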