Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Fix: debug pipeline... (#10311)
### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
@@ -20,6 +20,9 @@ import random
import sys
import threading
import time

import json_repair

from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
@@ -57,7 +60,7 @@ from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
email, tag
from rag.nlp import search, rag_tokenizer
from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string, truncate
@@ -477,6 +480,8 @@ async def run_dataflow(task: dict):
dataflow_id = task["dataflow_id"]
doc_id = task["doc_id"]
task_id = task["id"]
task_dataset_id = task["kb_id"]

if task["task_type"] == "dataflow":
e, cvs = UserCanvasService.get_by_id(dataflow_id)
assert e, "User pipeline not found."
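For orientation, a hypothetical `task` payload containing only the keys this function reads in the diff; the values are invented, not taken from RAGFlow.

```python
# Hypothetical task payload; keys mirror the task[...] reads visible in this diff,
# values are made up for illustration.
task = {
    "id": "task-123",                # task_id
    "task_type": "dataflow",         # selects the user-pipeline branch
    "dataflow_id": "canvas-456",     # user canvas / pipeline id
    "doc_id": "doc-789",             # set to CANVAS_DEBUG_DOC_ID for debug runs (see the early return below)
    "kb_id": "kb-001",               # dataset the chunks are written to
    "tenant_id": "tenant-abc",
    "file": None,                    # optional upload; falsy -> pipeline.run() is called without a file
}
```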
@@ -486,12 +491,12 @@ async def run_dataflow(task: dict):
assert e, "Pipeline log not found."
dsl = pipeline_log.dsl
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
chunks = await pipeline.run(file=task["file"]) if task.get("file") else await pipeline.run()
if doc_id == CANVAS_DEBUG_DOC_ID:
return

if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return

embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
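The second branch of the conditional expression above was previously not awaited, so a run without an uploaded file got back a coroutine object instead of a chunk list. A minimal standalone sketch of that pitfall (not RAGFlow code):

```python
import asyncio

async def run(file=None):
    # stand-in for Pipeline.run(); returns a list of chunks
    return ["chunk-1", "chunk-2"] if file is None else [f"chunk from {file}"]

async def main():
    file = None

    # Buggy form: only the first branch is awaited, so with file=None the
    # expression evaluates to an un-awaited coroutine object.
    chunks = await run(file=file) if file else run()
    print(type(chunks))        # <class 'coroutine'>
    chunks.close()             # silence the "never awaited" warning

    # Fixed form: each branch carries its own await.
    chunks = await run(file=file) if file else await run()
    print(chunks)              # ['chunk-1', 'chunk-2']

asyncio.run(main())
```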
@@ -508,7 +513,7 @@ async def run_dataflow(task: dict):

keys = [k for o in chunks for k in list(o.keys())]
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
set_progress(task_id, prog=0.82, msg="Start to embedding...")
set_progress(task_id, prog=0.82, msg="\n-------------------------------------\nStart to embedding...")
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
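The `q_[0-9]+_vec` test above skips the embedding stage when the chunks already carry vectors. A small self-contained sketch of the same check; the field names follow the diff, the sample chunks are invented:

```python
import re

chunks = [
    {"text": "first chunk", "q_1024_vec": [0.1] * 1024},  # already embedded (1024-dim vector field)
    {"text": "second chunk"},
]

keys = [k for o in chunks for k in o.keys()]
needs_embedding = not any(re.match(r"q_[0-9]+_vec", k) for k in keys)
print(needs_embedding)  # False: at least one chunk already has a q_<dim>_vec field
```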
@@ -518,7 +523,7 @@ async def run_dataflow(task: dict):
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE+1)
prog = 0.8
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter:
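The `+1` in the denominator above guards against small inputs: with fewer texts than one embedding batch, the integer division yields zero and the old expression divides by zero. A quick check (the EMBEDDING_BATCH_SIZE value here is illustrative):

```python
EMBEDDING_BATCH_SIZE = 16
texts = ["short doc"] * 5                                 # fewer texts than one batch

# old: 0.20 / (len(texts) // EMBEDDING_BATCH_SIZE)       -> ZeroDivisionError, since 5 // 16 == 0
delta = 0.20 / (len(texts) // EMBEDDING_BATCH_SIZE + 1)   # 0.20 / 1 == 0.2
print(delta)
```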
@@ -529,7 +534,8 @@ async def run_dataflow(task: dict):
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
if i % (len(texts)//EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")

assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
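The new `if` above reports progress only for an occasional batch instead of every one, so large documents do not flood the progress log. A sketch of that idea using an integer step for clarity (an assumption; the PR itself uses a float-modulo expression):

```python
EMBEDDING_BATCH_SIZE = 16
texts = ["t"] * 5000
total_batches = len(texts) // EMBEDDING_BATCH_SIZE

step = max(total_batches // 100, 1)       # report roughly 1% of batches, at least one in every batch
for batch_no, i in enumerate(range(0, len(texts), EMBEDDING_BATCH_SIZE)):
    if batch_no % step == 0:
        print(f"{i + 1} / {total_batches}")
```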
@@ -539,9 +545,23 @@ async def run_dataflow(task: dict):
metadata = {}
def dict_update(meta):
nonlocal metadata
if not meta or not isinstance(meta, dict):
if not meta:
return
for k,v in meta.items():
if isinstance(meta, str):
try:
meta = json_repair.loads(meta)
except Exception:
logging.error("Meta data format error.")
return
if not isinstance(meta, dict):
return
for k, v in meta.items():
if isinstance(v, list):
v = [vv for vv in v if isinstance(vv, str)]
if not v:
continue
if not isinstance(v, list) and not isinstance(v, str):
continue
if k not in metadata:
metadata[k] = v
continue
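The rewritten `dict_update` above now also accepts metadata delivered as a (possibly malformed) JSON string and drops values that are neither strings nor lists of strings. A self-contained paraphrase of that logic, runnable outside the task executor:

```python
import json_repair

metadata = {}

def dict_update(meta):
    if not meta:
        return
    if isinstance(meta, str):
        try:
            meta = json_repair.loads(meta)  # tolerant parse of slightly broken JSON
        except Exception:
            return                          # unrecoverable metadata string
    if not isinstance(meta, dict):
        return
    for k, v in meta.items():
        if isinstance(v, list):
            v = [vv for vv in v if isinstance(vv, str)]
            if not v:
                continue
        if not isinstance(v, (list, str)):
            continue
        metadata.setdefault(k, v)           # the diff continues with handling for keys that already exist

dict_update("{'author': 'alice', 'pages': 12, 'tags': ['x', 1]}")
print(metadata)  # {'author': 'alice', 'tags': ['x']}
```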
@@ -561,15 +581,29 @@ async def run_dataflow(task: dict):
ck["create_timestamp_flt"] = datetime.now().timestamp()
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
if "questions" in ck:
if "question_tks" not in ck:
ck["question_kwd"] = ck["questions"].split("\n")
ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
del ck["questions"]
if "keywords" in ck:
if "important_tks" not in ck:
ck["important_kwd"] = ck["keywords"].split(",")
ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
del ck["keywords"]
if "summary" in ck:
if "content_ltks" not in ck:
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
del ck["summary"]
if "metadata" in ck:
dict_update(ck["metadata"])
del ck["metadata"]
if "content_with_weight" not in ck:
ck["content_with_weight"] = ck["text"]
del ck["text"]
if "positions" in ck:
add_positions(ck, ck["positions"])
del ck["positions"]

if metadata:
e, doc = DocumentService.get_by_id(doc_id)
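The block above maps pipeline outputs (questions, keywords, summary, text, positions) onto the index schema, guarding each assignment so fields that were already tokenized upstream are kept. A runnable sketch of the questions/keywords part with a stand-in tokenizer (`rag_tokenizer.tokenize` and `add_positions` are the real calls inside RAGFlow):

```python
def tokenize(text: str) -> str:
    # stand-in for rag_tokenizer.tokenize, just to make the sketch runnable
    return " ".join(text.lower().split())

ck = {
    "text": "Payment terms are net 30.",
    "questions": "What are the payment terms?\nWhen is payment due?",
    "keywords": "payment,terms,net 30",
}

if "questions" in ck:
    if "question_tks" not in ck:
        ck["question_kwd"] = ck["questions"].split("\n")
        ck["question_tks"] = tokenize(str(ck["questions"]))
    del ck["questions"]
if "keywords" in ck:
    if "important_tks" not in ck:
        ck["important_kwd"] = ck["keywords"].split(",")
        ck["important_tks"] = tokenize(str(ck["keywords"]))
    del ck["keywords"]
if "content_with_weight" not in ck:
    ck["content_with_weight"] = ck["text"]
del ck["text"]

print(sorted(ck))
# ['content_with_weight', 'important_kwd', 'important_tks', 'question_kwd', 'question_tks']
```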
@@ -580,59 +614,18 @@ async def run_dataflow(task: dict):
DocumentService.update_by_id(doc_id, {"meta_fields": metadata})

start_ts = timer()
set_progress(task_id, prog=0.82, msg="Start to index...")
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return

time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)


@timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))

raptor = Raptor(
row["parser_config"]["raptor"].get("max_cluster", 64),
chat_mdl,
embd_mdl,
row["parser_config"]["raptor"]["prompt"],
row["parser_config"]["raptor"]["max_token"],
row["parser_config"]["raptor"]["threshold"]
)
original_length = len(chunks)
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]:
doc[PAGERANK_FLD] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:
d = copy.deepcopy(doc)
d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
d[vctr_nm] = vctr.tolist()
d["content_with_weight"] = content
d["content_ltks"] = rag_tokenizer.tokenize(content)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
res.append(d)
tk_count += num_tokens_from_string(content)
return res, tk_count
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))


@timeout(3600)
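Both `run_dataflow` above and the removed `run_raptor` derive chunk ids the same way: an xxhash of the chunk text concatenated with the document id. A minimal sketch of why that is convenient, since identical content within the same document always maps to the same id:

```python
import xxhash

def chunk_id(text: str, doc_id: str) -> str:
    # deterministic id: same text + same document id -> same hash
    return xxhash.xxh64((text + str(doc_id)).encode("utf-8")).hexdigest()

print(chunk_id("Payment terms are net 30.", "doc-789"))
print(chunk_id("Payment terms are net 30.", "doc-789"))  # identical to the line above
```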
@@ -787,7 +780,6 @@ async def do_handle_task(task):
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
async with kg_limiter:
# chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
chunks, token_count = await run_raptor_for_kb(
row=task,
kb_parser_config=kb_parser_config,
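The deleted per-document `run_raptor` read its settings from `parser_config["raptor"]`; the replacement `run_raptor_for_kb` appears to take the dataset-level `kb_parser_config` instead. For reference, a hypothetical config block with only the keys the removed function accessed; the values are placeholders, not RAGFlow defaults:

```python
# Hypothetical parser_config["raptor"] block; keys come from the removed
# run_raptor above, values are illustrative only.
parser_config = {
    "raptor": {
        "max_cluster": 64,                          # upper bound on clusters
        "prompt": "Summarize the following passages.",
        "max_token": 512,                           # summary length budget
        "threshold": 0.1,                           # clustering threshold
        "random_seed": 0,                           # passed to the clustering step
    }
}
```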
@@ -908,8 +900,8 @@ async def handle_task():
task_document_ids = []
if task_type in ["graphrag", "raptor"]:
task_document_ids = task["doc_ids"]
if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
if not task.get("dataflow_id", ""):
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)

redis_msg.ack()
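The guard above changes from "skip only canvas debug runs" to "skip every task bound to a dataflow pipeline", which presumably avoids a duplicate entry now that `run_dataflow` writes its own `PipelineOperationLogService` record with the pipeline DSL attached. A tiny sketch of the new condition; the helper name is invented:

```python
def should_log_here(task: dict) -> bool:
    # log from handle_task only when the task is not bound to a dataflow pipeline
    return not task.get("dataflow_id", "")

print(should_log_here({"doc_id": "doc-789"}))                          # True
print(should_log_here({"doc_id": "doc-789", "dataflow_id": "cv-1"}))   # False
```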