mirror of https://github.com/infiniflow/ragflow.git
synced 2026-02-07 02:55:08 +08:00

Compare commits: 14273b4595 ... c7efaab30e (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | c7efaab30e | |
| | ff49454501 | |
@@ -153,6 +153,16 @@ class Graph:
     def get_tenant_id(self):
         return self._tenant_id
 
+    def get_variable_value(self, exp: str) -> Any:
+        exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
+        if exp.find("@") < 0:
+            return self.globals[exp]
+        cpn_id, var_nm = exp.split("@")
+        cpn = self.get_component(cpn_id)
+        if not cpn:
+            raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
+        return cpn["obj"].output(var_nm)
+
 
 class Canvas(Graph):
@@ -406,16 +416,6 @@ class Canvas(Graph):
             return False
         return True
 
-    def get_variable_value(self, exp: str) -> Any:
-        exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
-        if exp.find("@") < 0:
-            return self.globals[exp]
-        cpn_id, var_nm = exp.split("@")
-        cpn = self.get_component(cpn_id)
-        if not cpn:
-            raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
-        return cpn["obj"].output(var_nm)
-
     def get_history(self, window_size):
         convs = []
         if window_size <= 0:
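Taken together, the two hunks above hoist get_variable_value from Canvas into the Graph base class unchanged, so any Graph subclass can resolve `{component@variable}` references. A minimal standalone sketch of the resolution rule (the dicts and the `sys.lang` key are illustrative stand-ins, not the real runtime state):

```python
from typing import Any

def resolve(exp: str, globals_: dict[str, Any], components: dict[str, Any]) -> Any:
    # "{begin@query}" -> "begin@query"; tolerate doubled braces and spaces.
    exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
    if exp.find("@") < 0:            # no component prefix: a global variable
        return globals_[exp]
    cpn_id, var_nm = exp.split("@")
    cpn = components.get(cpn_id)     # stand-in for self.get_component(cpn_id)
    if not cpn:
        raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
    return cpn["obj"].output(var_nm)

print(resolve("{sys.lang}", {"sys.lang": "en"}, {}))  # -> "en"
```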
@@ -102,6 +102,8 @@ class LLM(ComponentBase):
     def get_input_elements(self) -> dict[str, Any]:
         res = self.get_input_elements_from_text(self._param.sys_prompt)
+        if isinstance(self._param.prompts, str):
+            self._param.prompts = [{"role": "user", "content": self._param.prompts}]
         for prompt in self._param.prompts:
             d = self.get_input_elements_from_text(prompt["content"])
             res.update(d)
@@ -114,6 +116,8 @@ class LLM(ComponentBase):
         self._param.sys_prompt += txt
 
     def _sys_prompt_and_msg(self, msg, args):
+        if isinstance(self._param.prompts, str):
+            self._param.prompts = [{"role": "user", "content": self._param.prompts}]
         for p in self._param.prompts:
             if msg and msg[-1]["role"] == p["role"]:
                 continue
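Both get_input_elements and _sys_prompt_and_msg now tolerate `prompts` arriving as a bare string by wrapping it into the message-list shape before iterating. A sketch of the normalization on a hypothetical prompt:

```python
# A legacy string prompt is coerced into the chat-message list form.
prompts = "Summarize {begin@file}"
if isinstance(prompts, str):
    prompts = [{"role": "user", "content": prompts}]
assert prompts == [{"role": "user", "content": "Summarize {begin@file}"}]
```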
@@ -545,9 +545,6 @@ def run_graphrag():
     if task and task.progress not in [-1, 1]:
         return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
 
-    document_ids = []
-    sample_document = {}
-
     documents, _ = DocumentService.get_by_kb_id(
         kb_id=kb_id,
         page_number=0,
@@ -559,13 +556,11 @@ def run_graphrag():
         types=[],
         suffix=[],
     )
-    for document in documents:
-        if not sample_document and document["parser_config"].get("graphrag", {}).get("use_graphrag", False):
-            sample_document = document
-            document_ids.insert(0, document["id"])
-        else:
-            document_ids.append(document["id"])
+    if not documents:
+        return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
+
+    sample_document = documents[0]
+    document_ids = [document["id"] for document in documents]
 
     task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
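The rewritten selection logic fails fast on an empty knowledge base, takes the first document as the sample, and keeps document IDs in their original order instead of looping to find a use_graphrag-enabled document and splicing it to the front; run_raptor below receives the same treatment. A behavior sketch on hypothetical documents:

```python
documents = [{"id": "d1"}, {"id": "d2"}, {"id": "d3"}]

if not documents:                     # new guard: an empty KB is an error
    raise ValueError("No documents in Knowledgebase")
sample_document = documents[0]        # first document doubles as the sample
document_ids = [d["id"] for d in documents]
assert document_ids == ["d1", "d2", "d3"]
```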
@@ -586,7 +581,6 @@ def trace_graphrag():
     if not ok:
         return get_error_data_result(message="Invalid Knowledgebase ID")
 
-
     task_id = kb.graphrag_task_id
     if not task_id:
         return get_error_data_result(message="GraphRAG Task ID Not Found")
@@ -619,9 +613,6 @@ def run_raptor():
     if task and task.progress not in [-1, 1]:
         return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
 
-    document_ids = []
-    sample_document = {}
-
    documents, _ = DocumentService.get_by_kb_id(
         kb_id=kb_id,
         page_number=0,
@@ -633,13 +624,11 @@ def run_raptor():
         types=[],
         suffix=[],
     )
-    for document in documents:
-        if not sample_document:
-            sample_document = document
-            document_ids.insert(0, document["id"])
-        else:
-            document_ids.append(document["id"])
+    if not documents:
+        return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
+
+    sample_document = documents[0]
+    document_ids = [document["id"] for document in documents]
 
     task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
@@ -660,7 +649,6 @@ def trace_raptor():
     if not ok:
         return get_error_data_result(message="Invalid Knowledgebase ID")
 
-
     task_id = kb.raptor_task_id
     if not task_id:
         return get_error_data_result(message="RAPTOR Task ID Not Found")
@@ -285,7 +285,7 @@ async def run_graphrag_for_kb(
 
     if not with_resolution and not with_community:
         now = trio.current_time()
-        callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
+        callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
         return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
 
     await kb_lock.spin_acquire()
@@ -14,9 +14,10 @@
 # limitations under the License.
 import random
 from agent.component.llm import LLMParam, LLM
+from rag.flow.base import ProcessBase, ProcessParamBase
 
 
-class ExtractorParam(LLMParam):
+class ExtractorParam(ProcessParamBase, LLMParam):
     def __init__(self):
         super().__init__()
         self.field_name = ""
@@ -26,7 +27,7 @@ class ExtractorParam(LLMParam):
         self.check_empty(self.field_name, "Result Destination")
 
 
-class Extractor(LLM):
+class Extractor(ProcessBase, LLM):
     component_name = "Extractor"
 
     async def _invoke(self, **kwargs):
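Both classes now mix the pipeline bases in ahead of the LLM ones. A toy sketch (demo classes, not the real ones) of how C3 linearization orders cooperative super().__init__() calls under this pattern:

```python
# Listing the process base before the LLM base puts it first in the MRO,
# so cooperative super().__init__() visits it before the LLM base.
class LLMParamDemo:
    def __init__(self):
        super().__init__()
        print("LLMParamDemo.__init__")

class ProcessParamBaseDemo:
    def __init__(self):
        super().__init__()
        print("ProcessParamBaseDemo.__init__")

class ExtractorParamDemo(ProcessParamBaseDemo, LLMParamDemo):
    def __init__(self):
        super().__init__()  # walks ProcessParamBaseDemo, then LLMParamDemo

ExtractorParamDemo()
# Prints "LLMParamDemo.__init__" then "ProcessParamBaseDemo.__init__"
# (the innermost super returns first); the MRO is
# ExtractorParamDemo -> ProcessParamBaseDemo -> LLMParamDemo -> object.
```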
@@ -30,7 +30,7 @@ class ExtractorFromUpstream(BaseModel):
     json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
     markdown_result: str | None = Field(default=None, alias="markdown")
     text_result: str | None = Field(default=None, alias="text")
-    html_result: list[str] | None = Field(default=None, alias="html")
+    html_result: str | None = Field(default=None, alias="html")
 
     model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -29,7 +29,7 @@ class HierarchicalMergerFromUpstream(BaseModel):
     json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
     markdown_result: str | None = Field(default=None, alias="markdown")
     text_result: str | None = Field(default=None, alias="text")
-    html_result: list[str] | None = Field(default=None, alias="html")
+    html_result: str | None = Field(default=None, alias="html")
 
     model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -143,6 +143,10 @@ class Pipeline(Graph):
         async def invoke():
             nonlocal last_cpn, cpn_obj
             await cpn_obj.invoke(**last_cpn.output())
+            #if inspect.iscoroutinefunction(cpn_obj.invoke):
+            #    await cpn_obj.invoke(**last_cpn.output())
+            #else:
+            #    cpn_obj.invoke(**last_cpn.output())
 
         async with trio.open_nursery() as nursery:
             nursery.start_soon(invoke)
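For context, the surrounding code runs each component inside a trio nursery. A self-contained sketch of that structured-concurrency idiom (toy coroutine standing in for the real component call):

```python
import trio

async def main():
    async def invoke():
        # Stand-in for: await cpn_obj.invoke(**last_cpn.output())
        await trio.sleep(0.01)
        print("component finished")

    # The nursery block only exits once every task started in it completes.
    async with trio.open_nursery() as nursery:
        nursery.start_soon(invoke)

trio.run(main)
```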
@@ -30,7 +30,7 @@ class SplitterFromUpstream(BaseModel):
     json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
     markdown_result: str | None = Field(default=None, alias="markdown")
     text_result: str | None = Field(default=None, alias="text")
-    html_result: list[str] | None = Field(default=None, alias="html")
+    html_result: str | None = Field(default=None, alias="html")
 
     model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -31,7 +31,7 @@ class TokenizerFromUpstream(BaseModel):
     json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
     markdown_result: str | None = Field(default=None, alias="markdown")
     text_result: str | None = Field(default=None, alias="text")
-    html_result: list[str] | None = Field(default=None, alias="html")
+    html_result: str | None = Field(default=None, alias="html")
 
     model_config = ConfigDict(populate_by_name=True, extra="forbid")
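The same one-line type change lands in ExtractorFromUpstream, HierarchicalMergerFromUpstream, SplitterFromUpstream, and TokenizerFromUpstream: upstream html output is now a single string rather than a list of strings. A minimal pydantic sketch of how the field validates after the change (toy model, not the real ones):

```python
from pydantic import BaseModel, ConfigDict, Field

class FromUpstreamDemo(BaseModel):
    # After this change, "html" must arrive as one string, not list[str].
    html_result: str | None = Field(default=None, alias="html")
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

FromUpstreamDemo(html="<p>ok</p>")        # accepted
# FromUpstreamDemo(html=["<p>no</p>"])    # would now raise a ValidationError
```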
@@ -119,7 +119,7 @@ class Tokenizer(ProcessBase):
             if ck.get("questions"):
                 ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
             if ck.get("keywords"):
-                ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
+                ck["important_tks"] = rag_tokenizer.tokenize(",".join(ck["keywords"]))
             if ck.get("summary"):
                 ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
                 ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
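The only change here is the delimiter fed to the tokenizer for keywords. A quick illustration with hypothetical keywords:

```python
keywords = ["vector search", "RAG", "rerank"]
"\n".join(keywords)  # old: 'vector search\nRAG\nrerank'
",".join(keywords)   # new: 'vector search,RAG,rerank'
```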
@@ -222,10 +222,9 @@ async def collect():
         return None, None
 
     canceled = False
-    if msg.get("doc_id", "") == GRAPH_RAPTOR_FAKE_DOC_ID:
+    if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]:
         task = msg
         if task["task_type"] in ["graphrag", "raptor"] and msg.get("doc_ids", []):
-            print(f"hack {msg['doc_ids']=}=",flush=True)
             task = TaskService.get_task(msg["id"], msg["doc_ids"])
             task["doc_ids"] = msg["doc_ids"]
     else:
@@ -637,9 +636,11 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
 
 
 @timeout(3600)
-async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
+async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
     fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
 
+    raptor_config = kb_parser_config.get("raptor", {})
+
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for doc_id in doc_ids:
@@ -649,12 +650,12 @@ async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None,
             chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
 
     raptor = Raptor(
-        row["parser_config"]["raptor"].get("max_cluster", 64),
+        raptor_config.get("max_cluster", 64),
         chat_mdl,
         embd_mdl,
-        row["parser_config"]["raptor"]["prompt"],
-        row["parser_config"]["raptor"]["max_token"],
-        row["parser_config"]["raptor"]["threshold"]
+        raptor_config["prompt"],
+        raptor_config["max_token"],
+        raptor_config["threshold"],
     )
     original_length = len(chunks)
     chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
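The RAPTOR settings are now read once from the KB-level parser config passed in as kb_parser_config, instead of being re-fetched from row["parser_config"] at each use. A small sketch of the lookup pattern (the config values here are hypothetical):

```python
# Hypothetical KB parser config as stored on the knowledgebase record.
kb_parser_config = {
    "raptor": {"use_raptor": True, "prompt": "Summarize: {cluster_content}",
               "max_token": 512, "threshold": 0.1, "random_seed": 0},
}

# One .get() with a default keeps a missing section from raising KeyError.
raptor_config = kb_parser_config.get("raptor", {})
max_cluster = raptor_config.get("max_cluster", 64)   # falls back to 64
prompt = raptor_config["prompt"]                     # required key
```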
@@ -773,6 +774,15 @@ async def do_handle_task(task):
         return
 
     if task_type == "raptor":
+        ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
+        if not ok:
+            progress_callback(prog=-1.0, msg="Cannot found valid knowledgebase for RAPTOR task")
+            return
+
+        kb_parser_config = kb.parser_config
+        if not kb_parser_config.get("raptor", {}).get("use_raptor", False):
+            progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
+            return
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         # run RAPTOR
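Both this RAPTOR branch and the GraphRAG branch later in the diff now validate against the KB record itself rather than trusting config embedded in the task payload. The nested .get chain degrades safely when a section is missing (hypothetical configs):

```python
kb_parser_config = {}                                 # KB with no raptor section
assert kb_parser_config.get("raptor", {}).get("use_raptor", False) is False

kb_parser_config = {"raptor": {"use_raptor": True}}   # feature enabled
assert kb_parser_config.get("raptor", {}).get("use_raptor", False) is True
```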
@@ -780,6 +790,7 @@ async def do_handle_task(task):
         # chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
         chunks, token_count = await run_raptor_for_kb(
             row=task,
+            kb_parser_config=kb_parser_config,
             chat_mdl=chat_model,
             embd_mdl=embedding_model,
             vector_size=vector_size,
@@ -788,10 +799,17 @@ async def do_handle_task(task):
         )
     # Either using graphrag or Standard chunking methods
     elif task_type == "graphrag":
-        if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
-            progress_callback(prog=-1.0, msg="Internal configuration error.")
+        ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
+        if not ok:
+            progress_callback(prog=-1.0, msg="Cannot found valid knowledgebase for GraphRAG task")
             return
-        graphrag_conf = task["kb_parser_config"].get("graphrag", {})
+
+        kb_parser_config = kb.parser_config
+        if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False):
+            progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
+            return
+
+        graphrag_conf = kb_parser_config.get("graphrag", {})
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         with_resolution = graphrag_conf.get("resolution", False)
@@ -802,7 +820,7 @@ async def do_handle_task(task):
             row=task,
             doc_ids=task.get("doc_ids", []),
             language=task_language,
-            kb_parser_config=task_parser_config,
+            kb_parser_config=kb_parser_config,
             chat_model=chat_model,
             embedding_model=embedding_model,
             callback=progress_callback,