### What problem does this PR solve?

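Related issue: #4367. Among other changes, this PR registers a new `tag` parser type in the task executor and adds a tagging step to chunk building: each chunk is tagged from the configured tag knowledge bases, falling back to LLM-based `content_tagging` (with caching) when retrieval-based tagging yields nothing.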

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-01-09 17:07:21 +08:00
committed by GitHub
parent f892d7d426
commit c5da3cdd97
30 changed files with 736 additions and 202 deletions

View File

@ -16,10 +16,10 @@
# from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import random
import sys
from api.utils.log_utils import initRootLogger
from graphrag.utils import get_llm_cache, set_llm_cache
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
# Consumer number is taken from the first command-line argument; it falls
# back to "0" when the process is started without one.
CONSUMER_NO = sys.argv[1] if len(sys.argv) >= 2 else "0"
# Name of this executor process, suffixed with its consumer number.
CONSUMER_NAME = f"task_executor_{CONSUMER_NO}"
@ -44,7 +44,7 @@ import numpy as np
from peewee import DoesNotExist
from api.db import LLMType, ParserType, TaskStatus
from api.db.services.dialog_service import keyword_extraction, question_proposal
from api.db.services.dialog_service import keyword_extraction, question_proposal, content_tagging
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
@ -53,10 +53,10 @@ from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email
knowledge_graph, email, tag
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
@ -78,7 +78,8 @@ FACTORY = {
ParserType.ONE.value: one,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email,
ParserType.KG.value: knowledge_graph
ParserType.KG.value: knowledge_graph,
ParserType.TAG.value: tag
}
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
@ -199,7 +200,8 @@ def build_chunks(task, progress_callback):
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
except TimeoutError:
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
raise
except Exception as e:
if re.search("(No such file|not found)", str(e)):
@ -227,7 +229,7 @@ def build_chunks(task, progress_callback):
"kb_id": str(task["kb_id"])
}
if task["pagerank"]:
doc["pagerank_fea"] = int(task["pagerank"])
doc[PAGERANK_FLD] = int(task["pagerank"])
el = 0
for ck in cks:
d = copy.deepcopy(doc)
@ -252,7 +254,8 @@ def build_chunks(task, progress_callback):
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st
except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
raise
d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
@ -295,12 +298,43 @@ def build_chunks(task, progress_callback):
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []):
progress_callback(msg="Start to tag for every chunk ...")
kb_ids = task["kb_parser_config"]["tag_kb_ids"]
tenant_id = task["tenant_id"]
topn_tags = task["kb_parser_config"].get("topn_tags", 3)
S = 1000
st = timer()
examples = []
all_tags = get_tags_from_cache(kb_ids)
if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
set_tags_to_cache(kb_ids, all_tags)
else:
all_tags = json.loads(all_tags)
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
for d in docs:
if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
continue
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
cached = content_tagging(chat_mdl, d["content_with_weight"], all_tags,
random.choices(examples, k=2) if len(examples)>2 else examples,
topn=topn_tags)
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
return docs
def init_kb(row, vector_size: int):
    """Create the document-store index for the tenant referenced by ``row``.

    Args:
        row: Mapping-like task record. ``row["tenant_id"]`` is required and is
            turned into the index name via ``search.index_name``;
            ``row.get("kb_id", "")`` is passed along as the knowledge-base id
            (empty string when absent).
        vector_size: Passed through unchanged to ``createIdx``
            (presumably the embedding dimension -- confirm against the
            doc-store implementation).

    Returns:
        Whatever ``settings.docStoreConn.createIdx`` returns.
    """
    idxnm = search.index_name(row["tenant_id"])
    # Single return: the duplicated (unreachable) second return was removed.
    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
@ -381,7 +415,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]:
doc["pagerank_fea"] = int(row["pagerank"])
doc[PAGERANK_FLD] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:
@ -480,7 +514,8 @@ def do_handle_task(task):
doc_store_result = ""
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
task_dataset_id)
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
@ -493,15 +528,21 @@ def do_handle_task(task):
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
task_dataset_id)
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
timer() - start_ts))
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
token_count, time_cost))
def handle_task():