mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Tagging (#4426)
### What problem does this PR solve? #4367 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -16,10 +16,10 @@
|
||||
# from beartype import BeartypeConf
|
||||
# from beartype.claw import beartype_all # <-- you didn't sign up for this
|
||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||
|
||||
import random
|
||||
import sys
|
||||
from api.utils.log_utils import initRootLogger
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
||||
|
||||
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
||||
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
|
||||
@ -44,7 +44,7 @@ import numpy as np
|
||||
from peewee import DoesNotExist
|
||||
|
||||
from api.db import LLMType, ParserType, TaskStatus
|
||||
from api.db.services.dialog_service import keyword_extraction, question_proposal
|
||||
from api.db.services.dialog_service import keyword_extraction, question_proposal, content_tagging
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import TaskService
|
||||
@ -53,10 +53,10 @@ from api import settings
|
||||
from api.versions import get_ragflow_version
|
||||
from api.db.db_models import close_connection
|
||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
||||
knowledge_graph, email
|
||||
knowledge_graph, email, tag
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
|
||||
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
|
||||
from rag.utils import num_tokens_from_string
|
||||
from rag.utils.redis_conn import REDIS_CONN, Payload
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
@ -78,7 +78,8 @@ FACTORY = {
|
||||
ParserType.ONE.value: one,
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email,
|
||||
ParserType.KG.value: knowledge_graph
|
||||
ParserType.KG.value: knowledge_graph,
|
||||
ParserType.TAG.value: tag
|
||||
}
|
||||
|
||||
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
|
||||
@ -199,7 +200,8 @@ def build_chunks(task, progress_callback):
|
||||
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
|
||||
except TimeoutError:
|
||||
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
||||
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
|
||||
logging.exception(
|
||||
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
|
||||
raise
|
||||
except Exception as e:
|
||||
if re.search("(No such file|not found)", str(e)):
|
||||
@ -227,7 +229,7 @@ def build_chunks(task, progress_callback):
|
||||
"kb_id": str(task["kb_id"])
|
||||
}
|
||||
if task["pagerank"]:
|
||||
doc["pagerank_fea"] = int(task["pagerank"])
|
||||
doc[PAGERANK_FLD] = int(task["pagerank"])
|
||||
el = 0
|
||||
for ck in cks:
|
||||
d = copy.deepcopy(doc)
|
||||
@ -252,7 +254,8 @@ def build_chunks(task, progress_callback):
|
||||
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
|
||||
el += timer() - st
|
||||
except Exception:
|
||||
logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
|
||||
logging.exception(
|
||||
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
|
||||
raise
|
||||
|
||||
d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
|
||||
@ -295,12 +298,43 @@ def build_chunks(task, progress_callback):
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||
progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
|
||||
|
||||
if task["kb_parser_config"].get("tag_kb_ids", []):
|
||||
progress_callback(msg="Start to tag for every chunk ...")
|
||||
kb_ids = task["kb_parser_config"]["tag_kb_ids"]
|
||||
tenant_id = task["tenant_id"]
|
||||
topn_tags = task["kb_parser_config"].get("topn_tags", 3)
|
||||
S = 1000
|
||||
st = timer()
|
||||
examples = []
|
||||
all_tags = get_tags_from_cache(kb_ids)
|
||||
if not all_tags:
|
||||
all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
|
||||
set_tags_to_cache(kb_ids, all_tags)
|
||||
else:
|
||||
all_tags = json.loads(all_tags)
|
||||
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
for d in docs:
|
||||
if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
|
||||
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
|
||||
continue
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
|
||||
if not cached:
|
||||
cached = content_tagging(chat_mdl, d["content_with_weight"], all_tags,
|
||||
random.choices(examples, k=2) if len(examples)>2 else examples,
|
||||
topn=topn_tags)
|
||||
if cached:
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
|
||||
d[TAG_FLD] = json.loads(cached)
|
||||
|
||||
progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def init_kb(row, vector_size: int):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id",""), vector_size)
|
||||
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
|
||||
|
||||
def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
@ -381,7 +415,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
||||
"title_tks": rag_tokenizer.tokenize(row["name"])
|
||||
}
|
||||
if row["pagerank"]:
|
||||
doc["pagerank_fea"] = int(row["pagerank"])
|
||||
doc[PAGERANK_FLD] = int(row["pagerank"])
|
||||
res = []
|
||||
tk_count = 0
|
||||
for content, vctr in chunks[original_length:]:
|
||||
@ -480,7 +514,8 @@ def do_handle_task(task):
|
||||
doc_store_result = ""
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
|
||||
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
|
||||
task_dataset_id)
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
@ -493,15 +528,21 @@ def do_handle_task(task):
|
||||
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
|
||||
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)
|
||||
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
|
||||
task_dataset_id)
|
||||
return
|
||||
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))
|
||||
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
timer() - start_ts))
|
||||
|
||||
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
||||
|
||||
time_cost = timer() - start_ts
|
||||
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
|
||||
logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost))
|
||||
logging.info(
|
||||
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
token_count, time_cost))
|
||||
|
||||
|
||||
def handle_task():
|
||||
|
||||
Reference in New Issue
Block a user