diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py
index 42d8389fa..f8c10357f 100644
--- a/api/db/services/knowledgebase_service.py
+++ b/api/db/services/knowledgebase_service.py
@@ -397,9 +397,10 @@ class KnowledgebaseService(CommonService):
         else:
             kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
 
+        total = kbs.count()
         kbs = kbs.paginate(page_number, items_per_page)
 
-        return list(kbs.dicts()), kbs.count()
+        return list(kbs.dicts()), total
 
     @classmethod
     @DB.connection_context()
diff --git a/rag/flow/hierarchical_merger/hierarchical_merger.py b/rag/flow/hierarchical_merger/hierarchical_merger.py
index dda2bcfa7..e7b8b9def 100644
--- a/rag/flow/hierarchical_merger/hierarchical_merger.py
+++ b/rag/flow/hierarchical_merger/hierarchical_merger.py
@@ -166,7 +166,7 @@ class HierarchicalMerger(ProcessBase):
                 img = None
                 for i in path:
                     txt += lines[i] + "\n"
-                    concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get)))
+                    concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
                 cks.append(txt)
                 images.append(img)
 
@@ -180,7 +180,7 @@ class HierarchicalMerger(ProcessBase):
         ]
         async with trio.open_nursery() as nursery:
             for d in cks:
-                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
+                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
 
         self.set_output("chunks", cks)
         self.callback(1, "Done.")
diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py
index b1a34c59b..0784ce6fa 100644
--- a/rag/flow/parser/parser.py
+++ b/rag/flow/parser/parser.py
@@ -512,4 +512,4 @@ class Parser(ProcessBase):
         outs = self.output()
         async with trio.open_nursery() as nursery:
             for d in outs.get("json", []):
-                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
+                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py
index 9c6eb7bfd..24f62b6f1 100644
--- a/rag/flow/splitter/splitter.py
+++ b/rag/flow/splitter/splitter.py
@@ -87,7 +87,7 @@ class Splitter(ProcessBase):
         sections, section_images = [], []
         for o in from_upstream.json_result or []:
             sections.append((o.get("text", ""), o.get("position_tag", "")))
-            section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get)))
+            section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
 
         chunks, images = naive_merge_with_images(
             sections,
@@ -106,6 +106,6 @@ class Splitter(ProcessBase):
         ]
         async with trio.open_nursery() as nursery:
             for d in cks:
-                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
+                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
         self.set_output("chunks", cks)
         self.callback(1, "Done.")
diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py
index d2762ec91..1b593bf6b 100644
--- a/rag/prompts/generator.py
+++ b/rag/prompts/generator.py
@@ -680,8 +680,7 @@ async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
             chat_mdl,
             gen_conf={"temperature": 0.0, "top_p": 0.9}
         )
-        print(ans, "::::::::::::::::::::::::::::::::::::", flush=True)
-        txt_info["toc"] = ans if ans else []
+        txt_info["toc"] = ans if ans and not isinstance(ans, str) else []
         if callback:
             callback(msg="")
     except Exception as e:
@@ -728,8 +727,6 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
     for chunk in chunks_res:
         titles.extend(chunk.get("toc", []))
 
-
-    print(titles, ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
     # Filter out entries with title == -1
     prune = len(titles) > 512
 
@@ -745,12 +742,16 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
             filtered.append(x)
 
     logging.info(f"\n\nFiltered TOC sections:\n{filtered}")
+    if not filtered:
+        return []
 
     # Generate initial level (level/title)
    raw_structure = [x.get("title", "") for x in filtered]
 
     # Assign hierarchy levels using LLM
     toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
+    if not toc_with_levels:
+        return []
 
     # Merge structure and content (by index)
     prune = len(toc_with_levels) > 512
@@ -779,7 +780,6 @@ def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
         chat_mdl,
         gen_conf={"temperature": 0.0, "top_p": 0.9}
     )
-    print(ans, "::::::::::::::::::::::::::::::::::::", flush=True)
     id2score = {}
     for ti, sc in zip(toc, ans):
         if not isinstance(sc, dict) or sc.get("score", -1) < 1:
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 9801b53dd..14a0b5cda 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import concurrent
 # from beartype import BeartypeConf
 # from beartype.claw import beartype_all  # <-- you didn't sign up for this
 # beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code
@@ -317,7 +317,7 @@ async def build_chunks(task, progress_callback):
                     d["img_id"] = ""
                     docs.append(d)
                     return
-                await image2id(d, partial(STORAGE_IMPL.put), d["id"], task["kb_id"])
+                await image2id(d, partial(STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
                 docs.append(d)
             except Exception:
                 logging.exception(
@@ -370,38 +370,6 @@ async def build_chunks(task, progress_callback):
                 nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
         progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
-    if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
-        progress_callback(msg="Start to generate table of content ...")
-        chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
-        docs = sorted(docs, key=lambda d:(
-            d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
-            d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
-        ))
-        toc: list[dict] = await run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
-        logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
-        ii = 0
-        while ii < len(toc):
-            try:
-                idx = int(toc[ii]["chunk_id"])
-                del toc[ii]["chunk_id"]
-                toc[ii]["ids"] = [docs[idx]["id"]]
-                if ii == len(toc) -1:
-                    break
-                for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
-                    toc[ii]["ids"].append(docs[jj]["id"])
-            except Exception as e:
-                logging.exception(e)
-            ii += 1
-
-        if toc:
-            d = copy.deepcopy(docs[-1])
-            d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
-            d["toc_kwd"] = "toc"
-            d["available_int"] = 0
-            d["page_num_int"] = 100000000
-            d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
-            docs.append(d)
-
     if task["kb_parser_config"].get("tag_kb_ids", []):
         progress_callback(msg="Start to tag for every chunk ...")
         kb_ids = task["kb_parser_config"]["tag_kb_ids"]
@@ -451,6 +419,39 @@ async def build_chunks(task, progress_callback):
     return docs
 
 
+def build_TOC(task, docs, progress_callback):
+    progress_callback(msg="Start to generate table of content ...")
+    chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
+    docs = sorted(docs, key=lambda d:(
+        d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
+        d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
+    ))
+    toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
+    logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
+    ii = 0
+    while ii < len(toc):
+        try:
+            idx = int(toc[ii]["chunk_id"])
+            del toc[ii]["chunk_id"]
+            toc[ii]["ids"] = [docs[idx]["id"]]
+            if ii == len(toc) -1:
+                break
+            for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
+                toc[ii]["ids"].append(docs[jj]["id"])
+        except Exception as e:
+            logging.exception(e)
+        ii += 1
+
+    if toc:
+        d = copy.deepcopy(docs[-1])
+        d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
+        d["toc_kwd"] = "toc"
+        d["available_int"] = 0
+        d["page_num_int"] = 100000000
+        d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+        return d
+
+
 def init_kb(row, vector_size: int):
     idxnm = search.index_name(row["tenant_id"])
     return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
@@ -753,7 +754,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
         return True
 
 
-@timeout(60*60*2, 1)
+@timeout(60*60*3, 1)
 async def do_handle_task(task):
     task_type = task.get("task_type", "")
 
@@ -773,6 +774,8 @@ async def do_handle_task(task):
     task_document_name = task["name"]
     task_parser_config = task["parser_config"]
     task_start_ts = timer()
+    toc_thread = None
+    executor = concurrent.futures.ThreadPoolExecutor()
 
     # prepare the progress callback function
     progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -905,8 +908,6 @@ async def do_handle_task(task):
         if not chunks:
             progress_callback(1., msg=f"No chunk built from {task_document_name}")
             return
-        # TODO: exception handler
-        ## set_progress(task["did"], -1, "ERROR: ")
         progress_callback(msg="Generate {} chunks".format(len(chunks)))
         start_ts = timer()
         try:
@@ -920,6 +921,8 @@ async def do_handle_task(task):
         progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
         logging.info(progress_message)
         progress_callback(msg=progress_message)
+        if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
+            toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
 
         chunk_count = len(set([chunk["id"] for chunk in chunks]))
         start_ts = timer()
@@ -934,8 +937,17 @@ async def do_handle_task(task):
         DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
 
     time_cost = timer() - start_ts
+    progress_callback(msg="Indexing done ({:.2f}s).".format(time_cost))
+    if toc_thread:
+        d = toc_thread.result()
+        if d:
+            e = await insert_es(task_id, task_tenant_id, task_dataset_id, [d], progress_callback)
+            if not e:
+                return
+            DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)
+
     task_time_cost = timer() - task_start_ts
-    progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
+    progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
     logging.info(
         "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
                                                                                    task_to_page, len(chunks),
diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py
index c26e5606d..9d1aaccf2 100644
--- a/rag/utils/minio_conn.py
+++ b/rag/utils/minio_conn.py
@@ -60,7 +60,7 @@ class RAGFlowMinio:
         )
         return r
 
-    def put(self, bucket, fnm, binary):
+    def put(self, bucket, fnm, binary, tenant_id=None):
         for _ in range(3):
             try:
                 if not self.conn.bucket_exists(bucket):
@@ -76,13 +76,13 @@ class RAGFlowMinio:
                 self.__open__()
                 time.sleep(1)
 
-    def rm(self, bucket, fnm):
+    def rm(self, bucket, fnm, tenant_id=None):
         try:
             self.conn.remove_object(bucket, fnm)
         except Exception:
             logging.exception(f"Fail to remove {bucket}/{fnm}:")
 
-    def get(self, bucket, filename):
+    def get(self, bucket, filename, tenant_id=None):
         for _ in range(1):
             try:
                 r = self.conn.get_object(bucket, filename)
@@ -93,7 +93,7 @@ class RAGFlowMinio:
                 time.sleep(1)
         return
 
-    def obj_exist(self, bucket, filename):
+    def obj_exist(self, bucket, filename, tenant_id=None):
         try:
             if not self.conn.bucket_exists(bucket):
                 return False
@@ -121,7 +121,7 @@ class RAGFlowMinio:
                 logging.exception(f"bucket_exist {bucket} got exception")
                 return False
 
-    def get_presigned_url(self, bucket, fnm, expires):
+    def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
         for _ in range(10):
             try:
                 return self.conn.get_presigned_url("GET", bucket, fnm, expires)