diff --git a/rag/app/audio.py b/rag/app/audio.py
index 86fa759aa..979659382 100644
--- a/rag/app/audio.py
+++ b/rag/app/audio.py
@@ -34,7 +34,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
if not ext:
raise RuntimeError("No extension detected.")
- if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
+ if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma",
+ ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
raise RuntimeError(f"Extension {ext} is not supported yet.")
tmp_path = ""
diff --git a/rag/app/book.py b/rag/app/book.py
index bce85d84e..5f093c55b 100644
--- a/rag/app/book.py
+++ b/rag/app/book.py
@@ -22,7 +22,7 @@ from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
-from rag.nlp import bullets_category, is_english,remove_contents_table, \
+from rag.nlp import bullets_category, is_english, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks, attach_media_context
from rag.nlp import rag_tokenizer
@@ -91,9 +91,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
filename, binary=binary, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
- tbls = vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
+ tbls = vision_figure_parser_docx_wrapper(sections=sections, tbls=tbls, callback=callback, **kwargs)
# tbls = [((None, lns), None) for lns in tbls]
- sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
+ sections = [(item[0], item[1] if item[1] is not None else "") for item in sections if
+ not isinstance(item[1], Image.Image)]
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
@@ -109,14 +110,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
sections, tables, pdf_parser = parser(
- filename = filename,
- binary = binary,
- from_page = from_page,
- to_page = to_page,
- lang = lang,
- callback = callback,
- pdf_cls = Pdf,
- layout_recognizer = layout_recognizer,
+ filename=filename,
+ binary=binary,
+ from_page=from_page,
+ to_page=to_page,
+ lang=lang,
+ callback=callback,
+ pdf_cls=Pdf,
+ layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
**kwargs
)
@@ -126,7 +127,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
-
+
callback(0.8, "Finish parsing.")
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
@@ -175,7 +176,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
for ck in hierarchical_merge(bull, sections, 5)]
else:
sections = [s.split("@") for s, _ in sections]
- sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections ]
+ sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections]
chunks = naive_merge(
sections,
parser_config.get("chunk_token_num", 256),
@@ -199,6 +200,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
+
+
chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)
diff --git a/rag/app/email.py b/rag/app/email.py
index 6f3e30ab4..ea01a337e 100644
--- a/rag/app/email.py
+++ b/rag/app/email.py
@@ -26,13 +26,13 @@ import io
def chunk(
- filename,
- binary=None,
- from_page=0,
- to_page=100000,
- lang="Chinese",
- callback=None,
- **kwargs,
+ filename,
+ binary=None,
+ from_page=0,
+ to_page=100000,
+ lang="Chinese",
+ callback=None,
+ **kwargs,
):
"""
Only eml is supported
@@ -93,7 +93,8 @@ def chunk(
_add_content(msg, msg.get_content_type())
sections = TxtParser.parser_txt("\n".join(text_txt)) + [
- (line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
+ (line, "") for line in
+ HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
]
st = timer()
@@ -126,7 +127,9 @@ def chunk(
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
+
chunk(sys.argv[1], callback=dummy)
diff --git a/rag/app/laws.py b/rag/app/laws.py
index 97b58ca15..15c43e368 100644
--- a/rag/app/laws.py
+++ b/rag/app/laws.py
@@ -29,8 +29,6 @@ from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
-
-
class Docx(DocxParser):
def __init__(self):
pass
@@ -58,37 +56,36 @@ class Docx(DocxParser):
return [line for line in lines if line]
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
- self.doc = Document(
- filename) if not binary else Document(BytesIO(binary))
- pn = 0
- lines = []
- level_set = set()
- bull = bullets_category([p.text for p in self.doc.paragraphs])
- for p in self.doc.paragraphs:
- if pn > to_page:
- break
- question_level, p_text = docx_question_level(p, bull)
- if not p_text.strip("\n"):
+ self.doc = Document(
+ filename) if not binary else Document(BytesIO(binary))
+ pn = 0
+ lines = []
+ level_set = set()
+ bull = bullets_category([p.text for p in self.doc.paragraphs])
+ for p in self.doc.paragraphs:
+ if pn > to_page:
+ break
+ question_level, p_text = docx_question_level(p, bull)
+ if not p_text.strip("\n"):
+ continue
+ lines.append((question_level, p_text))
+ level_set.add(question_level)
+ for run in p.runs:
+ if 'lastRenderedPageBreak' in run._element.xml:
+ pn += 1
continue
- lines.append((question_level, p_text))
- level_set.add(question_level)
- for run in p.runs:
- if 'lastRenderedPageBreak' in run._element.xml:
- pn += 1
- continue
- if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
- pn += 1
+ if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
+ pn += 1
- sorted_levels = sorted(level_set)
+ sorted_levels = sorted(level_set)
- h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
- h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
+ h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
+ h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
- root = Node(level=0, depth=h2_level, texts=[])
- root.build_tree(lines)
-
- return [element for element in root.get_tree() if element]
+ root = Node(level=0, depth=h2_level, texts=[])
+ root.build_tree(lines)
+ return [element for element in root.get_tree() if element]
def __str__(self) -> str:
return f'''
@@ -121,8 +118,7 @@ class Pdf(PdfParser):
start = timer()
self._layouts_rec(zoomin)
callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
- logging.debug("layouts:".format(
- ))
+ logging.debug("layouts: {}".format((timer() - start)))
self._naive_vertical_merge()
callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
@@ -154,7 +150,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
chunks = Docx()(filename, binary)
callback(0.7, "Finish parsing.")
return tokenize_chunks(chunks, doc, eng, None)
-
+
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer, parser_model_name = normalize_layout_recognizer(
parser_config.get("layout_recognize", "DeepDOC")
@@ -168,14 +164,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
raw_sections, tables, pdf_parser = parser(
- filename = filename,
- binary = binary,
- from_page = from_page,
- to_page = to_page,
- lang = lang,
- callback = callback,
- pdf_cls = Pdf,
- layout_recognizer = layout_recognizer,
+ filename=filename,
+ binary=binary,
+ from_page=from_page,
+ to_page=to_page,
+ lang=lang,
+ callback=callback,
+ pdf_cls=Pdf,
+ layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
**kwargs
)
@@ -185,7 +181,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
-
+
for txt, poss in raw_sections:
sections.append(txt + poss)
@@ -226,7 +222,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)")
-
# Remove 'Contents' part
remove_contents_table(sections, eng)
@@ -234,7 +229,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
bull = bullets_category(sections)
res = tree_merge(bull, sections, 2)
-
if not res:
callback(0.99, "No chunk parsed out.")
@@ -243,9 +237,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
# chunks = hierarchical_merge(bull, sections, 5)
# return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
+
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
+
+
chunk(sys.argv[1], callback=dummy)
diff --git a/rag/app/manual.py b/rag/app/manual.py
index 969859e3e..0c85e8949 100644
--- a/rag/app/manual.py
+++ b/rag/app/manual.py
@@ -20,15 +20,17 @@ import re
from common.constants import ParserType
from io import BytesIO
-from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level, attach_media_context
+from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, \
+ docx_question_level, attach_media_context
from common.token_utils import num_tokens_from_string
from deepdoc.parser import PdfParser, DocxParser
-from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
+from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper, vision_figure_parser_docx_wrapper
from docx import Document
from PIL import Image
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
+
class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.MANUAL.value
@@ -129,11 +131,11 @@ class Docx(DocxParser):
question_level, p_text = 0, ''
if from_page <= pn < to_page and p.text.strip():
question_level, p_text = docx_question_level(p)
- if not question_level or question_level > 6: # not a question
+ if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{p_text}'
current_image = self.get_picture(self.doc, p)
last_image = self.concat_img(last_image, current_image)
- else: # is a question
+ else: # is a question
if last_answer or last_image:
sum_question = '\n'.join(question_stack)
if sum_question:
@@ -159,14 +161,14 @@ class Docx(DocxParser):
tbls = []
for tb in self.doc.tables:
- html= "
"
+ html = ""
for r in tb.rows:
html += ""
i = 0
while i < len(r.cells):
span = 1
c = r.cells[i]
- for j in range(i+1, len(r.cells)):
+ for j in range(i + 1, len(r.cells)):
if c.text == r.cells[j].text:
span += 1
i = j
@@ -211,16 +213,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
kwargs.pop("parse_method", None)
kwargs.pop("mineru_llm_name", None)
sections, tbls, pdf_parser = pdf_parser(
- filename = filename,
- binary = binary,
- from_page = from_page,
- to_page = to_page,
- lang = lang,
- callback = callback,
- pdf_cls = Pdf,
- layout_recognizer = layout_recognizer,
+ filename=filename,
+ binary=binary,
+ from_page=from_page,
+ to_page=to_page,
+ lang=lang,
+ callback=callback,
+ pdf_cls=Pdf,
+ layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
- parse_method = "manual",
+ parse_method="manual",
**kwargs
)
@@ -237,10 +239,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if isinstance(poss, str):
poss = pdf_parser.extract_positions(poss)
if poss:
- first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
+ first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
pn = first[0]
if isinstance(pn, list) and pn:
- pn = pn[0] # [pn] -> pn
+ pn = pn[0] # [pn] -> pn
poss[0] = (pn, *first[1:])
return (txt, layoutno, poss)
@@ -289,7 +291,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0], -1,
- [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
+ [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
def tag(pn, left, right, top, bottom):
if pn + left + right + top + bottom == 0:
@@ -312,7 +314,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
tk_cnt = num_tokens_from_string(txt)
if sec_id > -1:
last_sid = sec_id
- tbls = vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
+ tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))
@@ -325,7 +327,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
docx_parser = Docx()
ti_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
- tbls = vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
+ tbls = vision_figure_parser_docx_wrapper(sections=ti_list, tbls=tbls, callback=callback, **kwargs)
res = tokenize_table(tbls, doc, eng)
for text, image in ti_list:
d = copy.deepcopy(doc)
diff --git a/rag/app/naive.py b/rag/app/naive.py
index 1046a8f62..5f269d1c5 100644
--- a/rag/app/naive.py
+++ b/rag/app/naive.py
@@ -31,16 +31,20 @@ from common.token_utils import num_tokens_from_string
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html
-from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
-from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
+from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, \
+ PdfParser, TxtParser
+from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_docx_wrapper, \
+ vision_figure_parser_pdf_wrapper
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from deepdoc.parser.docling_parser import DoclingParser
from deepdoc.parser.tcadp_parser import TCADPParser
from common.parser_config_utils import normalize_layout_recognizer
-from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
+from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, \
+ tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
-def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
+def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
+ **kwargs):
callback = callback
binary = binary
pdf_parser = pdf_cls() if pdf_cls else Pdf()
@@ -58,17 +62,17 @@ def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
def by_mineru(
- filename,
- binary=None,
- from_page=0,
- to_page=100000,
- lang="Chinese",
- callback=None,
- pdf_cls=None,
- parse_method: str = "raw",
- mineru_llm_name: str | None = None,
- tenant_id: str | None = None,
- **kwargs,
+ filename,
+ binary=None,
+ from_page=0,
+ to_page=100000,
+ lang="Chinese",
+ callback=None,
+ pdf_cls=None,
+ parse_method: str = "raw",
+ mineru_llm_name: str | None = None,
+ tenant_id: str | None = None,
+ **kwargs,
):
pdf_parser = None
if tenant_id:
@@ -106,7 +110,8 @@ def by_mineru(
return None, None, None
-def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
+def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
+ **kwargs):
pdf_parser = DoclingParser()
parse_method = kwargs.get("parse_method", "raw")
@@ -125,7 +130,7 @@ def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
return sections, tables, pdf_parser
-def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
+def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None, **kwargs):
tcadp_parser = TCADPParser()
if not tcadp_parser.check_installation():
@@ -168,10 +173,10 @@ def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=No
PARSERS = {
- "deepdoc": by_deepdoc,
- "mineru": by_mineru,
- "docling": by_docling,
- "tcadp": by_tcadp,
+ "deepdoc": by_deepdoc,
+ "mineru": by_mineru,
+ "docling": by_docling,
+ "tcadp": by_tcadp,
"plaintext": by_plaintext, # default
}
@@ -264,7 +269,7 @@ class Docx(DocxParser):
# Find the nearest heading paragraph in reverse order
nearest_title = None
- for i in range(len(blocks)-1, -1, -1):
+ for i in range(len(blocks) - 1, -1, -1):
block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table
continue
@@ -293,7 +298,7 @@ class Docx(DocxParser):
# Find all parent headings, allowing cross-level search
while current_level > 1:
found = False
- for i in range(len(blocks)-1, -1, -1):
+ for i in range(len(blocks) - 1, -1, -1):
block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table
continue
@@ -426,7 +431,8 @@ class Docx(DocxParser):
try:
if inline_images:
- result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64))
+ result = mammoth.convert_to_html(docx_file,
+ convert_image=mammoth.images.img_element(_convert_image_to_base64))
else:
result = mammoth.convert_to_html(docx_file)
@@ -621,6 +627,7 @@ class Markdown(MarkdownParser):
return sections, tbls, section_images
return sections, tbls
+
def load_from_xml_v2(baseURI, rels_item_xml):
"""
Return |_SerializedRelationships| instance loaded with the
@@ -636,6 +643,7 @@ def load_from_xml_v2(baseURI, rels_item_xml):
srels._srels.append(_SerializedRelationship(baseURI, rel_elm))
return srels
+
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, excel, txt.
@@ -651,7 +659,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True})
- child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
+ child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode(
+ 'latin1').decode('utf-8')
cust_child_deli = re.findall(r"`([^`]+)`", child_deli)
child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))
if cust_child_deli:
@@ -685,7 +694,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
# Recursively chunk each embedded file and collect results
for embed_filename, embed_bytes in embeds:
try:
- sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False, **kwargs) or []
+ sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False,
+ **kwargs) or []
embed_res.extend(sub_res)
except Exception as e:
if callback:
@@ -704,7 +714,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
except Exception as e:
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
- sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
+ sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False,
+ **kwargs)
url_res.extend(sub_url_res)
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
@@ -747,14 +758,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
callback(0.1, "Start to parse.")
sections, tables, pdf_parser = parser(
- filename = filename,
- binary = binary,
- from_page = from_page,
- to_page = to_page,
- lang = lang,
- callback = callback,
- layout_recognizer = layout_recognizer,
- mineru_llm_name = parser_model_name,
+ filename=filename,
+ binary=binary,
+ from_page=from_page,
+ to_page=to_page,
+ lang=lang,
+ callback=callback,
+ layout_recognizer=layout_recognizer,
+ mineru_llm_name=parser_model_name,
**kwargs
)
@@ -846,9 +857,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
else:
section_images = [None] * len(sections)
section_images[idx] = combined_image
- markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
+ markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=[
+ ((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
boosted_figures = markdown_vision_parser(callback=callback)
- sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]), sections[idx][1])
+ sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]),
+ sections[idx][1])
else:
logging.warning("No visual model detected. Skipping figure parsing enhancement.")
@@ -945,7 +958,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
has_images = merged_images and any(img is not None for img in merged_images)
if has_images:
- res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images, child_delimiters_pattern=child_deli))
+ res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images,
+ child_delimiters_pattern=child_deli))
else:
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser, child_delimiters_pattern=child_deli))
else:
@@ -955,10 +969,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if section_images:
chunks, images = naive_merge_with_images(sections, section_images,
- int(parser_config.get(
- "chunk_token_num", 128)), parser_config.get(
- "delimiter", "\n!?。;!?"))
- res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli))
+ int(parser_config.get(
+ "chunk_token_num", 128)), parser_config.get(
+ "delimiter", "\n!?。;!?"))
+ res.extend(
+ tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli))
else:
chunks = naive_merge(
sections, int(parser_config.get(
@@ -993,7 +1008,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
+
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
diff --git a/rag/app/one.py b/rag/app/one.py
index cc34c0779..fe3a25430 100644
--- a/rag/app/one.py
+++ b/rag/app/one.py
@@ -26,6 +26,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
+
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
@@ -95,14 +96,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
sections, tbls, pdf_parser = parser(
- filename = filename,
- binary = binary,
- from_page = from_page,
- to_page = to_page,
- lang = lang,
- callback = callback,
- pdf_cls = Pdf,
- layout_recognizer = layout_recognizer,
+ filename=filename,
+ binary=binary,
+ from_page=from_page,
+ to_page=to_page,
+ lang=lang,
+ callback=callback,
+ pdf_cls=Pdf,
+ layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
**kwargs
)
@@ -112,9 +113,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
-
+
callback(0.8, "Finish parsing.")
-
+
for (img, rows), poss in tbls:
if not rows:
continue
@@ -172,7 +173,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
+
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
diff --git a/rag/app/paper.py b/rag/app/paper.py
index 22b57738c..4317c7a1d 100644
--- a/rag/app/paper.py
+++ b/rag/app/paper.py
@@ -20,7 +20,8 @@ import re
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
from common.constants import ParserType
-from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks, attach_media_context
+from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, \
+ tokenize_chunks, attach_media_context
from deepdoc.parser import PdfParser
import numpy as np
from rag.app.naive import by_plaintext, PARSERS
@@ -66,7 +67,7 @@ class Pdf(PdfParser):
# clean mess
if column_width < self.page_images[0].size[0] / zoomin / 2:
logging.debug("two_column................... {} {}".format(column_width,
- self.page_images[0].size[0] / zoomin / 2))
+ self.page_images[0].size[0] / zoomin / 2))
self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
for b in self.boxes:
b["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", b["text"].strip())
@@ -89,7 +90,7 @@ class Pdf(PdfParser):
title = ""
authors = []
i = 0
- while i < min(32, len(self.boxes)-1):
+ while i < min(32, len(self.boxes) - 1):
b = self.boxes[i]
i += 1
if b.get("layoutno", "").find("title") >= 0:
@@ -190,8 +191,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"tables": tables
}
- tbls=paper["tables"]
- tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
+ tbls = paper["tables"]
+ tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
paper["tables"] = tbls
else:
raise NotImplementedError("file type not supported yet(pdf supported)")
@@ -329,6 +330,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
+
+
chunk(sys.argv[1], callback=dummy)
diff --git a/rag/app/picture.py b/rag/app/picture.py
index bc93ab279..c60b7e85e 100644
--- a/rag/app/picture.py
+++ b/rag/app/picture.py
@@ -51,7 +51,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
}
)
cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
- ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
+ ans = asyncio.run(
+ cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
ans += "\n" + ans
tokenize(doc, ans, eng)
diff --git a/rag/app/presentation.py b/rag/app/presentation.py
index e4a093634..26c08183e 100644
--- a/rag/app/presentation.py
+++ b/rag/app/presentation.py
@@ -249,7 +249,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys
+
def dummy(a, b):
pass
+
chunk(sys.argv[1], callback=dummy)
diff --git a/rag/app/qa.py b/rag/app/qa.py
index ecf60ec4f..a31240bd3 100644
--- a/rag/app/qa.py
+++ b/rag/app/qa.py
@@ -102,9 +102,9 @@ class Pdf(PdfParser):
self._text_merge()
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
tbls = self._extract_table_figure(True, zoomin, True, True)
- #self._naive_vertical_merge()
+ # self._naive_vertical_merge()
# self._concat_downward()
- #self._filter_forpages()
+ # self._filter_forpages()
logging.debug("layouts: {}".format(timer() - start))
sections = [b["text"] for b in self.boxes]
bull_x0_list = []
@@ -114,12 +114,14 @@ class Pdf(PdfParser):
qai_list = []
last_q, last_a, last_tag = '', '', ''
last_index = -1
- last_box = {'text':''}
+ last_box = {'text': ''}
last_bull = None
+
def sort_key(element):
tbls_pn = element[1][0][0]
tbls_top = element[1][0][3]
return tbls_pn, tbls_top
+
tbls.sort(key=sort_key)
tbl_index = 0
last_pn, last_bottom = 0, 0
@@ -133,28 +135,32 @@ class Pdf(PdfParser):
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
if not has_bull: # No question bullet
if not last_q:
- if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
+ if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
tbl_index += 1
continue
else:
sum_tag = line_tag
sum_section = section
- while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
- and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the middle of current answer
+ while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
+ and ((tbl_pn == line_pn and tbl_top <= line_top) or (
+ tbl_pn < line_pn)): # add image at the middle of current answer
sum_tag = f'{tbl_tag}{sum_tag}'
sum_section = f'{tbl_text}{sum_section}'
tbl_index += 1
- tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
+ tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
+ tbl_index)
last_a = f'{last_a}{sum_section}'
last_tag = f'{last_tag}{sum_tag}'
else:
if last_q:
- while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
- and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the end of last answer
+ while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
+ and ((tbl_pn == line_pn and tbl_top <= line_top) or (
+ tbl_pn < line_pn)): # add image at the end of last answer
last_tag = f'{last_tag}{tbl_tag}'
last_a = f'{last_a}{tbl_text}'
tbl_index += 1
- tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
+ tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
+ tbl_index)
image, poss = self.crop(last_tag, need_position=True)
qai_list.append((last_q, last_a, image, poss))
last_q, last_a, last_tag = '', '', ''
@@ -171,7 +177,7 @@ class Pdf(PdfParser):
def get_tbls_info(self, tbls, tbl_index):
if tbl_index >= len(tbls):
return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
- tbl_pn = tbls[tbl_index][1][0][0]+1
+ tbl_pn = tbls[tbl_index][1][0][0] + 1
tbl_left = tbls[tbl_index][1][0][1]
tbl_right = tbls[tbl_index][1][0][2]
tbl_top = tbls[tbl_index][1][0][3]
@@ -210,11 +216,11 @@ class Docx(DocxParser):
question_level, p_text = 0, ''
if from_page <= pn < to_page and p.text.strip():
question_level, p_text = docx_question_level(p)
- if not question_level or question_level > 6: # not a question
+ if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{p_text}'
current_image = self.get_picture(self.doc, p)
last_image = concat_img(last_image, current_image)
- else: # is a question
+ else: # is a question
if last_answer or last_image:
sum_question = '\n'.join(question_stack)
if sum_question:
@@ -240,14 +246,14 @@ class Docx(DocxParser):
tbls = []
for tb in self.doc.tables:
- html= ""
+ html = ""
for r in tb.rows:
html += ""
i = 0
while i < len(r.cells):
span = 1
c = r.cells[i]
- for j in range(i+1, len(r.cells)):
+ for j in range(i + 1, len(r.cells)):
if c.text == r.cells[j].text:
span += 1
i = j
@@ -356,7 +362,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if question:
answer += "\n" + lines[i]
else:
- fails.append(str(i+1))
+ fails.append(str(i + 1))
elif len(arr) == 2:
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
@@ -429,13 +435,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if not code_block:
question_level, question = mdQuestionLevel(line)
- if not question_level or question_level > 6: # not a question
+ if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{line}'
- else: # is a question
+ else: # is a question
if last_answer.strip():
sum_question = '\n'.join(question_stack)
if sum_question:
- res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
+ res.append(beAdoc(deepcopy(doc), sum_question,
+ markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
last_answer = ''
i = question_level
@@ -447,13 +454,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if last_answer.strip():
sum_question = '\n'.join(question_stack)
if sum_question:
- res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
+ res.append(beAdoc(deepcopy(doc), sum_question,
+ markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
return res
elif re.search(r"\.docx$", filename, re.IGNORECASE):
docx_parser = Docx()
qai_list, tbls = docx_parser(filename, binary,
- from_page=0, to_page=10000, callback=callback)
+ from_page=0, to_page=10000, callback=callback)
res = tokenize_table(tbls, doc, eng)
for i, (q, a, image) in enumerate(qai_list):
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
@@ -466,6 +474,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
+
+
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
diff --git a/rag/app/resume.py b/rag/app/resume.py
index fc6bc6556..b022f81b3 100644
--- a/rag/app/resume.py
+++ b/rag/app/resume.py
@@ -64,7 +64,8 @@ def remote_call(filename, binary):
del resume[k]
resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
- "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
+ "updated_at": datetime.datetime.now().strftime(
+ "%Y-%m-%d %H:%M:%S")}]))
resume = step_two.parse(resume)
return resume
except Exception:
@@ -171,6 +172,9 @@ def chunk(filename, binary=None, callback=None, **kwargs):
if __name__ == "__main__":
import sys
+
def dummy(a, b):
pass
+
+
chunk(sys.argv[1], callback=dummy)
diff --git a/rag/app/table.py b/rag/app/table.py
index bb6a4d007..4ffbee367 100644
--- a/rag/app/table.py
+++ b/rag/app/table.py
@@ -51,14 +51,15 @@ class Excel(ExcelParser):
tables = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]
- images = Excel._extract_images_from_worksheet(ws,sheetname=sheetname)
+ images = Excel._extract_images_from_worksheet(ws, sheetname=sheetname)
if images:
- image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback, **kwargs)
+ image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback,
+ **kwargs)
if image_descriptions and len(image_descriptions) == len(images):
for i, bf in enumerate(image_descriptions):
images[i]["image_description"] = "\n".join(bf[0][1])
for img in images:
- if (img["span_type"] == "single_cell"and img.get("image_description")):
+ if (img["span_type"] == "single_cell" and img.get("image_description")):
pending_cell_images.append(img)
else:
flow_images.append(img)
@@ -113,16 +114,17 @@ class Excel(ExcelParser):
tables.append(
(
(
- img["image"], # Image.Image
- [img["image_description"]] # description list (must be list)
+ img["image"], # Image.Image
+ [img["image_description"]] # description list (must be list)
),
[
- (0, 0, 0, 0, 0) # dummy position
+ (0, 0, 0, 0, 0) # dummy position
]
)
)
- callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
- return res,tables
+ callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (
+ f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+ return res, tables
def _parse_headers(self, ws, rows):
if len(rows) == 0:
@@ -315,14 +317,15 @@ def trans_bool(s):
def column_data_type(arr):
arr = list(arr)
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
- trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
+ trans = {t: f for f, t in
+ [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
float_flag = False
for a in arr:
if a is None:
continue
if re.match(r"[+-]?[0-9]+$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"):
counts["int"] += 1
- if int(str(a)) > 2**63 - 1:
+ if int(str(a)) > 2 ** 63 - 1:
float_flag = True
break
elif re.match(r"[+-]?[0-9.]{,19}$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"):
@@ -370,7 +373,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
- dfs,tbls = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback, **kwargs)
+ dfs, tbls = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback, **kwargs)
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
@@ -389,7 +392,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
continue
rows.append(row)
- callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+ callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
+ f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
elif re.search(r"\.csv$", filename, re.IGNORECASE):
@@ -406,7 +410,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
fails = []
rows = []
- for i, row in enumerate(all_rows[1 + from_page : 1 + to_page]):
+ for i, row in enumerate(all_rows[1 + from_page: 1 + to_page]):
if len(row) != len(headers):
fails.append(str(i + from_page))
continue
@@ -415,7 +419,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
callback(
0.3,
(f"Extract records: {from_page}~{from_page + len(rows)}" +
- (f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else ""))
+ (f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else ""))
)
dfs = [pd.DataFrame(rows, columns=headers)]
@@ -445,7 +449,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
df[clmns[j]] = cln
if ty == "text":
txts.extend([str(c) for c in cln if c])
- clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))]
+ clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
+ range(len(clmns))]
eng = lang.lower() == "english" # is_english(txts)
for ii, row in df.iterrows():
@@ -477,7 +482,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
+
chunk(sys.argv[1], callback=dummy)
diff --git a/rag/app/tag.py b/rag/app/tag.py
index fda91f1a3..8e516a75f 100644
--- a/rag/app/tag.py
+++ b/rag/app/tag.py
@@ -141,17 +141,20 @@ def label_question(question, kbs):
if not tag_kbs:
return tags
tags = settings.retriever.tag_query(question,
- list(set([kb.tenant_id for kb in tag_kbs])),
- tag_kb_ids,
- all_tags,
- kb.parser_config.get("topn_tags", 3)
- )
+ list(set([kb.tenant_id for kb in tag_kbs])),
+ tag_kb_ids,
+ all_tags,
+ kb.parser_config.get("topn_tags", 3)
+ )
return tags
if __name__ == "__main__":
import sys
+
def dummy(prog=None, msg=""):
pass
- chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
\ No newline at end of file
+
+
+ chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index 3cebff329..de269320d 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -263,7 +263,7 @@ class SparkTTS(Base):
raise Exception(error)
def on_close(self, ws, close_status_code, close_msg):
- self.audio_queue.put(None) # 放入 None 作为结束标志
+ self.audio_queue.put(None) # None is terminator
def on_open(self, ws):
def run(*args):
diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index 641fdf5ab..76ba6f4e5 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -273,7 +273,7 @@ def tokenize(d, txt, eng):
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
-def split_with_pattern(d, pattern:str, content:str, eng) -> list:
+def split_with_pattern(d, pattern: str, content: str, eng) -> list:
docs = []
txts = [txt for txt in re.split(r"(%s)" % pattern, content, flags=re.DOTALL)]
for j in range(0, len(txts), 2):
@@ -281,7 +281,7 @@ def split_with_pattern(d, pattern:str, content:str, eng) -> list:
if not txt:
continue
if j + 1 < len(txts):
- txt += txts[j+1]
+ txt += txts[j + 1]
dd = copy.deepcopy(d)
tokenize(dd, txt, eng)
docs.append(dd)
@@ -304,7 +304,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None, child_delimiters_pattern=
except NotImplementedError:
pass
else:
- add_positions(d, [[ii]*5])
+ add_positions(d, [[ii] * 5])
if child_delimiters_pattern:
d["mom_with_weight"] = ck
@@ -325,7 +325,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images, child_delimiters_patte
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
d["image"] = image
- add_positions(d, [[ii]*5])
+ add_positions(d, [[ii] * 5])
if child_delimiters_pattern:
d["mom_with_weight"] = ck
res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng))
@@ -658,7 +658,8 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
if "content_ltks" in ck:
ck["content_ltks"] = rag_tokenizer.tokenize(combined)
if "content_sm_ltks" in ck:
- ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
+ ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
+ ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
if positioned_indices:
chunks[:] = [chunks[i] for i in ordered_indices]
@@ -764,8 +765,8 @@ def not_title(txt):
return True
return re.search(r"[,;,。;!!]", txt)
-def tree_merge(bull, sections, depth):
+def tree_merge(bull, sections, depth):
if not sections or bull < 0:
return sections
if isinstance(sections[0], type("")):
@@ -777,16 +778,17 @@ def tree_merge(bull, sections, depth):
def get_level(bull, section):
text, layout = section
- text = re.sub(r"\u3000", " ", text).strip()
+ text = re.sub(r"\u3000", " ", text).strip()
for i, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, text.strip()):
- return i+1, text
+ return i + 1, text
else:
if re.search(r"(title|head)", layout) and not not_title(text):
- return len(BULLET_PATTERN[bull])+1, text
+ return len(BULLET_PATTERN[bull]) + 1, text
else:
- return len(BULLET_PATTERN[bull])+2, text
+ return len(BULLET_PATTERN[bull]) + 2, text
+
level_set = set()
lines = []
for section in sections:
@@ -812,8 +814,8 @@ def tree_merge(bull, sections, depth):
return [element for element in root.get_tree() if element]
-def hierarchical_merge(bull, sections, depth):
+def hierarchical_merge(bull, sections, depth):
if not sections or bull < 0:
return []
if isinstance(sections[0], type("")):
@@ -922,10 +924,10 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
- if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
+ if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
- t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
+ t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
@@ -957,7 +959,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
return cks
for sec, pos in sections:
- add_chunk("\n"+sec, pos)
+ add_chunk("\n" + sec, pos)
return cks
@@ -978,10 +980,10 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
- if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
+ if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
- t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
+ t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
@@ -1025,9 +1027,9 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
if isinstance(text, tuple):
text_str = text[0]
text_pos = text[1] if len(text) > 1 else ""
- add_chunk("\n"+text_str, image, text_pos)
+ add_chunk("\n" + text_str, image, text_pos)
else:
- add_chunk("\n"+text, image)
+ add_chunk("\n" + text, image)
return cks, result_images
@@ -1042,7 +1044,7 @@ def docx_question_level(p, bull=-1):
for j, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, txt):
return j + 1, txt
- return len(BULLET_PATTERN[bull])+1, txt
+ return len(BULLET_PATTERN[bull]) + 1, txt
def concat_img(img1, img2):
@@ -1211,7 +1213,7 @@ class Node:
child = node.get_children()
if level == 0 and texts:
- tree_list.append("\n".join(titles+texts))
+ tree_list.append("\n".join(titles + texts))
# Titles within configured depth are accumulated into the current path
if 1 <= level <= self.depth:
diff --git a/rag/nlp/query.py b/rag/nlp/query.py
index 1cb2f4071..402b240fe 100644
--- a/rag/nlp/query.py
+++ b/rag/nlp/query.py
@@ -205,11 +205,11 @@ class FulltextQueryer(QueryBase):
s = 1e-9
for k, v in qtwt.items():
if k in dtwt:
- s += v #* dtwt[k]
+ s += v # * dtwt[k]
q = 1e-9
for k, v in qtwt.items():
- q += v #* v
- return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
+ q += v # * v
+ return s / q # math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
if isinstance(content_tks, str):
@@ -232,4 +232,5 @@ class FulltextQueryer(QueryBase):
keywords.append(f"{tk}^{w}")
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
- {"minimum_should_match": min(3, len(keywords) / 10), "original_query": " ".join(origin_keywords)})
+ {"minimum_should_match": min(3, len(keywords) / 10),
+ "original_query": " ".join(origin_keywords)})
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 8748cbd97..01f55c9ef 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -66,7 +66,8 @@ class Dealer:
if key in req and req[key] is not None:
condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
- for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
+ for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
+ "removed_kwd"]:
if key in req and req[key] is not None:
condition[key] = req[key]
return condition
@@ -141,7 +142,8 @@ class Dealer:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
- orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
+ orderBy, offset, limit, idx_names, kb_ids,
+ rank_feature=rank_feature)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
@@ -218,8 +220,9 @@ class Dealer:
ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]):
- chunk_v[i] = [0.0]*len(ans_v[0])
- logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
+ chunk_v[i] = [0.0] * len(ans_v[0])
+ logging.warning(
+ "The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
@@ -273,7 +276,7 @@ class Dealer:
if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
- q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
+ q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD):
@@ -286,8 +289,8 @@ class Dealer:
if denor == 0:
rank_fea.append(0)
else:
- rank_fea.append(nor/np.sqrt(denor)/q_denor)
- return np.array(rank_fea)*10. + pageranks
+ rank_fea.append(nor / np.sqrt(denor) / q_denor)
+ return np.array(rank_fea) * 10. + pageranks
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
@@ -358,21 +361,21 @@ class Dealer:
rag_tokenizer.tokenize(inst).split())
def retrieval(
- self,
- question,
- embd_mdl,
- tenant_ids,
- kb_ids,
- page,
- page_size,
- similarity_threshold=0.2,
- vector_similarity_weight=0.3,
- top=1024,
- doc_ids=None,
- aggs=True,
- rerank_mdl=None,
- highlight=False,
- rank_feature: dict | None = {PAGERANK_FLD: 10},
+ self,
+ question,
+ embd_mdl,
+ tenant_ids,
+ kb_ids,
+ page,
+ page_size,
+ similarity_threshold=0.2,
+ vector_similarity_weight=0.3,
+ top=1024,
+ doc_ids=None,
+ aggs=True,
+ rerank_mdl=None,
+ highlight=False,
+ rank_feature: dict | None = {PAGERANK_FLD: 10},
):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
@@ -395,7 +398,8 @@ class Dealer:
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
- sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
+ sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
+ rank_feature=rank_feature)
if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model(
@@ -558,13 +562,14 @@ class Dealer:
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(tenant_id)
- match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
+ match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
+ keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
- tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
+ tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
return True
@@ -580,11 +585,11 @@ class Dealer:
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
- tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
+ tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
- def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6):
+ def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
if not chunks:
return []
idx_nms = [index_name(tid) for tid in tenant_ids]
@@ -594,9 +599,10 @@ class Dealer:
ranks[ck["doc_id"]] = 0
ranks[ck["doc_id"]] += ck["similarity"]
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
- doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0]
+ doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
kb_ids = [doc_id2kb_id[doc_id]]
- es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
+ es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
+ OrderByExpr(), 0, 128, idx_nms,
kb_ids)
toc = []
dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
@@ -608,7 +614,7 @@ class Dealer:
if not toc:
return chunks
- ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn*2))
+ ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2))
if not ids:
return chunks
@@ -644,9 +650,9 @@ class Dealer:
break
chunks.append(d)
- return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]
+ return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]
- def retrieval_by_children(self, chunks:list[dict], tenant_ids:list[str]):
+ def retrieval_by_children(self, chunks: list[dict], tenant_ids: list[str]):
if not chunks:
return []
idx_nms = [index_name(tid) for tid in tenant_ids]
@@ -692,4 +698,4 @@ class Dealer:
break
chunks.append(d)
- return sorted(chunks, key=lambda x:x["similarity"]*-1)
+ return sorted(chunks, key=lambda x: x["similarity"] * -1)
diff --git a/rag/nlp/surname.py b/rag/nlp/surname.py
index 0b6a682d4..39c28be95 100644
--- a/rag/nlp/surname.py
+++ b/rag/nlp/surname.py
@@ -14,129 +14,131 @@
# limitations under the License.
#
-m = set(["赵","钱","孙","李",
-"周","吴","郑","王",
-"冯","陈","褚","卫",
-"蒋","沈","韩","杨",
-"朱","秦","尤","许",
-"何","吕","施","张",
-"孔","曹","严","华",
-"金","魏","陶","姜",
-"戚","谢","邹","喻",
-"柏","水","窦","章",
-"云","苏","潘","葛",
-"奚","范","彭","郎",
-"鲁","韦","昌","马",
-"苗","凤","花","方",
-"俞","任","袁","柳",
-"酆","鲍","史","唐",
-"费","廉","岑","薛",
-"雷","贺","倪","汤",
-"滕","殷","罗","毕",
-"郝","邬","安","常",
-"乐","于","时","傅",
-"皮","卞","齐","康",
-"伍","余","元","卜",
-"顾","孟","平","黄",
-"和","穆","萧","尹",
-"姚","邵","湛","汪",
-"祁","毛","禹","狄",
-"米","贝","明","臧",
-"计","伏","成","戴",
-"谈","宋","茅","庞",
-"熊","纪","舒","屈",
-"项","祝","董","梁",
-"杜","阮","蓝","闵",
-"席","季","麻","强",
-"贾","路","娄","危",
-"江","童","颜","郭",
-"梅","盛","林","刁",
-"钟","徐","邱","骆",
-"高","夏","蔡","田",
-"樊","胡","凌","霍",
-"虞","万","支","柯",
-"昝","管","卢","莫",
-"经","房","裘","缪",
-"干","解","应","宗",
-"丁","宣","贲","邓",
-"郁","单","杭","洪",
-"包","诸","左","石",
-"崔","吉","钮","龚",
-"程","嵇","邢","滑",
-"裴","陆","荣","翁",
-"荀","羊","於","惠",
-"甄","曲","家","封",
-"芮","羿","储","靳",
-"汲","邴","糜","松",
-"井","段","富","巫",
-"乌","焦","巴","弓",
-"牧","隗","山","谷",
-"车","侯","宓","蓬",
-"全","郗","班","仰",
-"秋","仲","伊","宫",
-"宁","仇","栾","暴",
-"甘","钭","厉","戎",
-"祖","武","符","刘",
-"景","詹","束","龙",
-"叶","幸","司","韶",
-"郜","黎","蓟","薄",
-"印","宿","白","怀",
-"蒲","邰","从","鄂",
-"索","咸","籍","赖",
-"卓","蔺","屠","蒙",
-"池","乔","阴","鬱",
-"胥","能","苍","双",
-"闻","莘","党","翟",
-"谭","贡","劳","逄",
-"姬","申","扶","堵",
-"冉","宰","郦","雍",
-"郤","璩","桑","桂",
-"濮","牛","寿","通",
-"边","扈","燕","冀",
-"郏","浦","尚","农",
-"温","别","庄","晏",
-"柴","瞿","阎","充",
-"慕","连","茹","习",
-"宦","艾","鱼","容",
-"向","古","易","慎",
-"戈","廖","庾","终",
-"暨","居","衡","步",
-"都","耿","满","弘",
-"匡","国","文","寇",
-"广","禄","阙","东",
-"欧","殳","沃","利",
-"蔚","越","夔","隆",
-"师","巩","厍","聂",
-"晁","勾","敖","融",
-"冷","訾","辛","阚",
-"那","简","饶","空",
-"曾","母","沙","乜",
-"养","鞠","须","丰",
-"巢","关","蒯","相",
-"查","后","荆","红",
-"游","竺","权","逯",
-"盖","益","桓","公",
-"兰","原","乞","西","阿","肖","丑","位","曽","巨","德","代","圆","尉","仵","纳","仝","脱","丘","但","展","迪","付","覃","晗","特","隋","苑","奥","漆","谌","郄","练","扎","邝","渠","信","门","陳","化","原","密","泮","鹿","赫",
-"万俟","司马","上官","欧阳",
-"夏侯","诸葛","闻人","东方",
-"赫连","皇甫","尉迟","公羊",
-"澹台","公冶","宗政","濮阳",
-"淳于","单于","太叔","申屠",
-"公孙","仲孙","轩辕","令狐",
-"钟离","宇文","长孙","慕容",
-"鲜于","闾丘","司徒","司空",
-"亓官","司寇","仉督","子车",
-"颛孙","端木","巫马","公西",
-"漆雕","乐正","壤驷","公良",
-"拓跋","夹谷","宰父","榖梁",
-"晋","楚","闫","法","汝","鄢","涂","钦",
-"段干","百里","东郭","南门",
-"呼延","归","海","羊舌","微","生",
-"岳","帅","缑","亢","况","后","有","琴",
-"梁丘","左丘","东门","西门",
-"商","牟","佘","佴","伯","赏","南宫",
-"墨","哈","谯","笪","年","爱","阳","佟",
-"第五","言","福"])
+m = set(["赵", "钱", "孙", "李",
+ "周", "吴", "郑", "王",
+ "冯", "陈", "褚", "卫",
+ "蒋", "沈", "韩", "杨",
+ "朱", "秦", "尤", "许",
+ "何", "吕", "施", "张",
+ "孔", "曹", "严", "华",
+ "金", "魏", "陶", "姜",
+ "戚", "谢", "邹", "喻",
+ "柏", "水", "窦", "章",
+ "云", "苏", "潘", "葛",
+ "奚", "范", "彭", "郎",
+ "鲁", "韦", "昌", "马",
+ "苗", "凤", "花", "方",
+ "俞", "任", "袁", "柳",
+ "酆", "鲍", "史", "唐",
+ "费", "廉", "岑", "薛",
+ "雷", "贺", "倪", "汤",
+ "滕", "殷", "罗", "毕",
+ "郝", "邬", "安", "常",
+ "乐", "于", "时", "傅",
+ "皮", "卞", "齐", "康",
+ "伍", "余", "元", "卜",
+ "顾", "孟", "平", "黄",
+ "和", "穆", "萧", "尹",
+ "姚", "邵", "湛", "汪",
+ "祁", "毛", "禹", "狄",
+ "米", "贝", "明", "臧",
+ "计", "伏", "成", "戴",
+ "谈", "宋", "茅", "庞",
+ "熊", "纪", "舒", "屈",
+ "项", "祝", "董", "梁",
+ "杜", "阮", "蓝", "闵",
+ "席", "季", "麻", "强",
+ "贾", "路", "娄", "危",
+ "江", "童", "颜", "郭",
+ "梅", "盛", "林", "刁",
+ "钟", "徐", "邱", "骆",
+ "高", "夏", "蔡", "田",
+ "樊", "胡", "凌", "霍",
+ "虞", "万", "支", "柯",
+ "昝", "管", "卢", "莫",
+ "经", "房", "裘", "缪",
+ "干", "解", "应", "宗",
+ "丁", "宣", "贲", "邓",
+ "郁", "单", "杭", "洪",
+ "包", "诸", "左", "石",
+ "崔", "吉", "钮", "龚",
+ "程", "嵇", "邢", "滑",
+ "裴", "陆", "荣", "翁",
+ "荀", "羊", "於", "惠",
+ "甄", "曲", "家", "封",
+ "芮", "羿", "储", "靳",
+ "汲", "邴", "糜", "松",
+ "井", "段", "富", "巫",
+ "乌", "焦", "巴", "弓",
+ "牧", "隗", "山", "谷",
+ "车", "侯", "宓", "蓬",
+ "全", "郗", "班", "仰",
+ "秋", "仲", "伊", "宫",
+ "宁", "仇", "栾", "暴",
+ "甘", "钭", "厉", "戎",
+ "祖", "武", "符", "刘",
+ "景", "詹", "束", "龙",
+ "叶", "幸", "司", "韶",
+ "郜", "黎", "蓟", "薄",
+ "印", "宿", "白", "怀",
+ "蒲", "邰", "从", "鄂",
+ "索", "咸", "籍", "赖",
+ "卓", "蔺", "屠", "蒙",
+ "池", "乔", "阴", "鬱",
+ "胥", "能", "苍", "双",
+ "闻", "莘", "党", "翟",
+ "谭", "贡", "劳", "逄",
+ "姬", "申", "扶", "堵",
+ "冉", "宰", "郦", "雍",
+ "郤", "璩", "桑", "桂",
+ "濮", "牛", "寿", "通",
+ "边", "扈", "燕", "冀",
+ "郏", "浦", "尚", "农",
+ "温", "别", "庄", "晏",
+ "柴", "瞿", "阎", "充",
+ "慕", "连", "茹", "习",
+ "宦", "艾", "鱼", "容",
+ "向", "古", "易", "慎",
+ "戈", "廖", "庾", "终",
+ "暨", "居", "衡", "步",
+ "都", "耿", "满", "弘",
+ "匡", "国", "文", "寇",
+ "广", "禄", "阙", "东",
+ "欧", "殳", "沃", "利",
+ "蔚", "越", "夔", "隆",
+ "师", "巩", "厍", "聂",
+ "晁", "勾", "敖", "融",
+ "冷", "訾", "辛", "阚",
+ "那", "简", "饶", "空",
+ "曾", "母", "沙", "乜",
+ "养", "鞠", "须", "丰",
+ "巢", "关", "蒯", "相",
+ "查", "后", "荆", "红",
+ "游", "竺", "权", "逯",
+ "盖", "益", "桓", "公",
+ "兰", "原", "乞", "西", "阿", "肖", "丑", "位", "曽", "巨", "德", "代", "圆", "尉", "仵", "纳", "仝", "脱",
+ "丘", "但", "展", "迪", "付", "覃", "晗", "特", "隋", "苑", "奥", "漆", "谌", "郄", "练", "扎", "邝", "渠",
+ "信", "门", "陳", "化", "原", "密", "泮", "鹿", "赫",
+ "万俟", "司马", "上官", "欧阳",
+ "夏侯", "诸葛", "闻人", "东方",
+ "赫连", "皇甫", "尉迟", "公羊",
+ "澹台", "公冶", "宗政", "濮阳",
+ "淳于", "单于", "太叔", "申屠",
+ "公孙", "仲孙", "轩辕", "令狐",
+ "钟离", "宇文", "长孙", "慕容",
+ "鲜于", "闾丘", "司徒", "司空",
+ "亓官", "司寇", "仉督", "子车",
+ "颛孙", "端木", "巫马", "公西",
+ "漆雕", "乐正", "壤驷", "公良",
+ "拓跋", "夹谷", "宰父", "榖梁",
+ "晋", "楚", "闫", "法", "汝", "鄢", "涂", "钦",
+ "段干", "百里", "东郭", "南门",
+ "呼延", "归", "海", "羊舌", "微", "生",
+ "岳", "帅", "缑", "亢", "况", "后", "有", "琴",
+ "梁丘", "左丘", "东门", "西门",
+ "商", "牟", "佘", "佴", "伯", "赏", "南宫",
+ "墨", "哈", "谯", "笪", "年", "爱", "阳", "佟",
+ "第五", "言", "福"])
-def isit(n):return n.strip() in m
+def isit(n): return n.strip() in m
diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py
index 28ed585ee..4ab410129 100644
--- a/rag/nlp/term_weight.py
+++ b/rag/nlp/term_weight.py
@@ -1,4 +1,4 @@
- #
+#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -108,13 +108,14 @@ class Dealer:
if re.match(p, t):
tk = "#"
break
- #tk = re.sub(r"([\+\\-])", r"\\\1", tk)
+ # tk = re.sub(r"([\+\\-])", r"\\\1", tk)
if tk != "#" and tk:
res.append(tk)
return res
def token_merge(self, tks):
- def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
+ def one_term(t):
+ return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
res, i = [], 0
while i < len(tks):
@@ -152,8 +153,8 @@ class Dealer:
tks = []
for t in re.sub(r"[ \t]+", " ", txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
- re.match(r".*[a-zA-Z]$", t) and tks and \
- self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
+ re.match(r".*[a-zA-Z]$", t) and tks and \
+ self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
tks[-1] = tks[-1] + " " + t
else:
tks.append(t)
@@ -220,14 +221,15 @@ class Dealer:
return 3
- def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
+ def idf(s, N):
+ return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
tw = []
if not preprocess:
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
wts = (0.3 * idf1 + 0.7 * idf2) * \
- np.array([ner(t) * postag(t) for t in tks])
+ np.array([ner(t) * postag(t) for t in tks])
wts = [s for s in wts]
tw = list(zip(tks, wts))
else:
@@ -236,7 +238,7 @@ class Dealer:
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \
- np.array([ner(t) * postag(t) for t in tt])
+ np.array([ner(t) * postag(t) for t in tt])
wts = [s for s in wts]
tw.extend(zip(tt, wts))
diff --git a/rag/prompts/__init__.py b/rag/prompts/__init__.py
index b8b924b93..a2d991705 100644
--- a/rag/prompts/__init__.py
+++ b/rag/prompts/__init__.py
@@ -3,4 +3,4 @@ from . import generator
__all__ = [name for name in dir(generator)
if not name.startswith('_')]
-globals().update({name: getattr(generator, name) for name in __all__})
\ No newline at end of file
+globals().update({name: getattr(generator, name) for name in __all__})
diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py
index 3ace3a8e7..acc5e582d 100644
--- a/rag/prompts/generator.py
+++ b/rag/prompts/generator.py
@@ -28,17 +28,16 @@ from rag.prompts.template import load_prompt
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string
-
-STOP_TOKEN="<|STOP|>"
-COMPLETE_TASK="complete_task"
+STOP_TOKEN = "<|STOP|>"
+COMPLETE_TASK = "complete_task"
INPUT_UTILIZATION = 0.5
+
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
def chunks_format(reference):
-
return [
{
"id": get_value(chunk, "chunk_id", "id"),
@@ -126,7 +125,7 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500))
cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
- cnt += draw_node("URL", ck['url']) if "url" in ck else ""
+ cnt += draw_node("URL", ck['url']) if "url" in ck else ""
for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
cnt += draw_node(k, v)
cnt += "\n└── Content:\n"
@@ -173,7 +172,7 @@ ASK_SUMMARY = load_prompt("ask_summary")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
-def citation_prompt(user_defined_prompts: dict={}) -> str:
+def citation_prompt(user_defined_prompts: dict = {}) -> str:
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
return template.render()
@@ -258,9 +257,11 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
- rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
+ rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query,
+ languages=languages)
- ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
+ ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}],
+ {"temperature": 0.2})
ans = re.sub(r"^.*", "", ans, flags=re.DOTALL)
if ans.find("**ERROR**") >= 0:
return query
@@ -332,7 +333,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
"description": "When you have the final answer and are ready to complete the task, call this function with your answer",
"parameters": {
"type": "object",
- "properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
+ "properties": {
+ "answer": {"type": "string", "description": "The final answer to the user's question"}},
"required": ["answer"]
}
}
@@ -341,7 +343,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
name = tool["function"]["name"]
desc[name] = tool
- return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
+ return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in
+ enumerate(desc.items())])
def form_history(history, limit=-6):
@@ -350,14 +353,14 @@ def form_history(history, limit=-6):
if h["role"] == "system":
continue
role = "USER"
- if h["role"].upper()!= role:
+ if h["role"].upper() != role:
role = "AGENT"
- context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
+ context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
return context
-
-async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict],
+ user_defined_prompts: dict = {}):
tools_desc = tool_schema(tools_description)
context = ""
@@ -375,7 +378,8 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
return kwd
-async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+async def next_step_async(chat_mdl, history: list, tools_description: list[dict], task_desc,
+ user_defined_prompts: dict = {}):
if not tools_description:
return "", 0
desc = tool_schema(tools_description)
@@ -396,7 +400,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
return json_str, tk_cnt
-async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict = {}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@@ -419,7 +423,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple
def form_message(system_prompt, user_prompt):
- return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
+ return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
def structured_output_prompt(schema=None) -> str:
@@ -427,27 +431,29 @@ def structured_output_prompt(schema=None) -> str:
return template.render(schema=schema)
-async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
+async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict = {}) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name,
- params=json.dumps(params, ensure_ascii=False, indent=2),
- result=result)
+ params=json.dumps(params, ensure_ascii=False, indent=2),
+ result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*", "", ans, flags=re.DOTALL)
-async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str],
+ user_defined_prompts: dict = {}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
- system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
+ system_prompt = template.render(goal=goal, sub_goal=sub_goal,
+ results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*", "", ans, flags=re.DOTALL)
-async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
+async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict:
meta_data_structure = {}
for key, values in meta_data.items():
meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
@@ -471,13 +477,13 @@ async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
return {"conditions": []}
-async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
+async def gen_json(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None):
from graphrag.utils import get_llm_cache, set_llm_cache
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
- ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
+ ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], gen_conf=gen_conf)
ans = re.sub(r"(^.*|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
res = json_repair.loads(ans)
@@ -488,10 +494,13 @@ async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None
TOC_DETECTION = load_prompt("toc_detection")
-async def detect_table_of_contents(page_1024:list[str], chat_mdl):
+
+
+async def detect_table_of_contents(page_1024: list[str], chat_mdl):
toc_secs = []
for i, sec in enumerate(page_1024[:22]):
- ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
+ ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.",
+ chat_mdl)
if toc_secs and not ans["exists"]:
break
toc_secs.append(sec)
@@ -500,14 +509,17 @@ async def detect_table_of_contents(page_1024:list[str], chat_mdl):
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
+
+
async def extract_table_of_contents(toc_pages, chat_mdl):
if not toc_pages:
return []
- return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
+ return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)),
+ "Only JSON please.", chat_mdl)
-async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
+async def toc_index_extractor(toc: list[dict], content: str, chat_mdl):
tob_extractor_prompt = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
@@ -529,18 +541,21 @@ async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
If the title of the section are not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""
- prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
+ prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False,
+ indent=2) + '\nDocument pages:\n' + content
return await gen_json(prompt, "Only JSON please.", chat_mdl)
TOC_INDEX = load_prompt("toc_index")
+
+
async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
if not toc_arr or not sections:
return []
toc_map = {}
for i, it in enumerate(toc_arr):
- k1 = (it["structure"]+it["title"]).replace(" ", "")
+ k1 = (it["structure"] + it["title"]).replace(" ", "")
k2 = it["title"].strip()
if k1 not in toc_map:
toc_map[k1] = []
@@ -558,6 +573,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
toc_arr[j]["indices"].append(i)
all_pathes = []
+
def dfs(start, path):
nonlocal all_pathes
if start >= len(toc_arr):
@@ -565,7 +581,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
all_pathes.append(path)
return
if not toc_arr[start]["indices"]:
- dfs(start+1, path)
+ dfs(start + 1, path)
return
added = False
for j in toc_arr[start]["indices"]:
@@ -574,12 +590,12 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
_path = deepcopy(path)
_path.append((j, start))
added = True
- dfs(start+1, _path)
+ dfs(start + 1, _path)
if not added and path:
all_pathes.append(path)
dfs(0, [])
- path = max(all_pathes, key=lambda x:len(x))
+ path = max(all_pathes, key=lambda x: len(x))
for it in toc_arr:
it["indices"] = []
for j, i in path:
@@ -588,24 +604,24 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
i = 0
while i < len(toc_arr):
- it = toc_arr[i]
+ it = toc_arr[i]
if it["indices"]:
i += 1
continue
- if i>0 and toc_arr[i-1]["indices"]:
- st_i = toc_arr[i-1]["indices"][-1]
+ if i > 0 and toc_arr[i - 1]["indices"]:
+ st_i = toc_arr[i - 1]["indices"][-1]
else:
st_i = 0
e = i + 1
- while e = len(toc_arr):
e = len(sections)
else:
e = toc_arr[e]["indices"][0]
- for j in range(st_i, min(e+1, len(sections))):
+ for j in range(st_i, min(e + 1, len(sections))):
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
structure=it["structure"],
title=it["title"],
@@ -656,11 +672,15 @@ async def toc_transformer(toc_pages, chat_mdl):
toc_content = "\n".join(toc_pages)
prompt = init_prompt + '\n Given table of contents\n:' + toc_content
+
def clean_toc(arr):
for a in arr:
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
+
last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
- if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+ if_complete = await check_if_toc_transformation_is_complete(toc_content,
+ json.dumps(last_complete, ensure_ascii=False, indent=2),
+ chat_mdl)
clean_toc(last_complete)
if if_complete == "yes":
return last_complete
@@ -682,13 +702,17 @@ async def toc_transformer(toc_pages, chat_mdl):
break
clean_toc(new_complete)
last_complete.extend(new_complete)
- if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+ if_complete = await check_if_toc_transformation_is_complete(toc_content,
+ json.dumps(last_complete, ensure_ascii=False,
+ indent=2), chat_mdl)
return last_complete
TOC_LEVELS = load_prompt("assign_toc_levels")
-async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
+
+
+async def assign_toc_levels(toc_secs, chat_mdl, gen_conf={"temperature": 0.2}):
if not toc_secs:
return []
return await gen_json(
@@ -701,12 +725,15 @@ async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2})
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
+
+
# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
try:
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
- PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
+ PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(
+ text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
@@ -743,7 +770,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)
- input_budget = 1024 if input_budget > 1024 else input_budget
+ input_budget = 1024 if input_budget > 1024 else input_budget
chunk_sections = split_chunks(chunks, input_budget)
titles = []
@@ -798,7 +825,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
if sorted_list:
max_lvl = sorted_list[-1]
merged = []
- for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
+ for _, (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
if prune and toc_item.get("level", "0") >= max_lvl:
continue
merged.append({
@@ -812,12 +839,15 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
-async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
+
+
+async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
import numpy as np
try:
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
- PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
+ PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join(
+ [json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
@@ -828,17 +858,19 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
for id in ti.get("ids", []):
if id not in id2score:
id2score[id] = []
- id2score[id].append(sc["score"]/5.)
+ id2score[id].append(sc["score"] / 5.)
for id in id2score.keys():
id2score[id] = np.mean(id2score[id])
- return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn]
+ return [(id, sc) for id, sc in list(id2score.items()) if sc >= 0.3][:topn]
except Exception as e:
logging.exception(e)
return []
META_DATA = load_prompt("meta_data")
-async def gen_metadata(chat_mdl, schema:dict, content:str):
+
+
+async def gen_metadata(chat_mdl, schema: dict, content: str):
template = PROMPT_JINJA_ENV.from_string(META_DATA)
for k, desc in schema["properties"].items():
if "enum" in desc and not desc.get("enum"):
@@ -849,4 +881,4 @@ async def gen_metadata(chat_mdl, schema:dict, content:str):
user_prompt = "Output: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
- return re.sub(r"^.*", "", ans, flags=re.DOTALL)
\ No newline at end of file
+ return re.sub(r"^.*", "", ans, flags=re.DOTALL)
diff --git a/rag/prompts/template.py b/rag/prompts/template.py
index 654e71c5c..69079818a 100644
--- a/rag/prompts/template.py
+++ b/rag/prompts/template.py
@@ -1,6 +1,5 @@
import os
-
PROMPT_DIR = os.path.dirname(__file__)
_loaded_prompts = {}
diff --git a/rag/svr/cache_file_svr.py b/rag/svr/cache_file_svr.py
index 3744c04ea..f71712f79 100644
--- a/rag/svr/cache_file_svr.py
+++ b/rag/svr/cache_file_svr.py
@@ -48,13 +48,15 @@ def main():
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))
except Exception as e:
- traceback.print_stack(e)
+ logging.error(f"Error to get data from REDIS: {e}")
+ traceback.print_stack()
except Exception as e:
- traceback.print_stack(e)
+ logging.error(f"Error to check REDIS connection: {e}")
+ traceback.print_stack()
if __name__ == "__main__":
while True:
main()
close_connection()
- time.sleep(1)
\ No newline at end of file
+ time.sleep(1)
diff --git a/rag/svr/discord_svr.py b/rag/svr/discord_svr.py
index ec842c45c..1c663d708 100644
--- a/rag/svr/discord_svr.py
+++ b/rag/svr/discord_svr.py
@@ -19,16 +19,15 @@ import requests
import base64
import asyncio
-URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
+URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
JSON_DATA = {
- "conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
- "Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
- "word": "" # User question, don't need to initialize
+ "conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
+ "Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
+ "word": "" # User question, don't need to initialize
}
-DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" #Get DISCORD_BOT_KEY from Discord Application
-
+DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" # Get DISCORD_BOT_KEY from Discord Application
intents = discord.Intents.default()
intents.message_content = True
@@ -50,7 +49,7 @@ async def on_message(message):
if len(message.content.split('> ')) == 1:
await message.channel.send("Hi~ How can I help you? ")
else:
- JSON_DATA['word']=message.content.split('> ')[1]
+ JSON_DATA['word'] = message.content.split('> ')[1]
response = requests.post(URL, json=JSON_DATA)
response_data = response.json().get('data', [])
image_bool = False
@@ -61,9 +60,9 @@ async def on_message(message):
if i['type'] == 3:
image_bool = True
image_data = base64.b64decode(i['url'])
- with open('tmp_image.png','wb') as file:
+ with open('tmp_image.png', 'wb') as file:
file.write(image_data)
- image= discord.File('tmp_image.png')
+ image = discord.File('tmp_image.png')
await message.channel.send(f"{message.author.mention}{res}")
diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py
index d79c2a20c..3e2172d4b 100644
--- a/rag/svr/sync_data_source.py
+++ b/rag/svr/sync_data_source.py
@@ -38,7 +38,8 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
from common.config_utils import show_configs
-from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
+from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, \
+ MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
from common.constants import FileSource, TaskStatus
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector
@@ -96,7 +97,7 @@ class SyncBase:
if task["poll_range_start"]:
next_update = task["poll_range_start"]
- for document_batch in document_batch_generator:
+ for document_batch in document_batch_generator:
if not document_batch:
continue
@@ -161,6 +162,7 @@ class SyncBase:
def _get_source_prefix(self):
return ""
+
class _BlobLikeBase(SyncBase):
DEFAULT_BUCKET_TYPE: str = "s3"
@@ -199,22 +201,27 @@ class _BlobLikeBase(SyncBase):
)
return document_batch_generator
+
class S3(_BlobLikeBase):
SOURCE_NAME: str = FileSource.S3
DEFAULT_BUCKET_TYPE: str = "s3"
+
class R2(_BlobLikeBase):
SOURCE_NAME: str = FileSource.R2
DEFAULT_BUCKET_TYPE: str = "r2"
+
class OCI_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.OCI_STORAGE
DEFAULT_BUCKET_TYPE: str = "oci_storage"
+
class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.GOOGLE_CLOUD_STORAGE
DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"
+
class Confluence(SyncBase):
SOURCE_NAME: str = FileSource.CONFLUENCE
@@ -248,7 +255,9 @@ class Confluence(SyncBase):
index_recursively=index_recursively,
)
- credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
+ credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"],
+ connector_name=DocumentSource.CONFLUENCE,
+ credential_json=self.conf["credentials"])
self.connector.set_credentials_provider(credentials_provider)
# Determine the time range for synchronization based on reindex or poll_range_start
@@ -280,7 +289,8 @@ class Confluence(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator:
if failure is not None:
- logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure))
+ logging.warning("Confluence connector failure: %s",
+ getattr(failure, "failure_message", failure))
continue
if document is not None:
pending_docs.append(document)
@@ -300,7 +310,7 @@ class Confluence(SyncBase):
async def async_wrapper():
for batch in document_batches():
yield batch
-
+
logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
return async_wrapper()
@@ -314,10 +324,12 @@ class Notion(SyncBase):
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
- else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
+ else self.connector.poll_source(task["poll_range_start"].timestamp(),
+ datetime.now(timezone.utc).timestamp())
)
- begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
+ begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
+ task["poll_range_start"])
logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
return document_generator
@@ -340,10 +352,12 @@ class Discord(SyncBase):
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
- else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
+ else self.connector.poll_source(task["poll_range_start"].timestamp(),
+ datetime.now(timezone.utc).timestamp())
)
- begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
+ begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
+ task["poll_range_start"])
logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
return document_generator
@@ -485,7 +499,8 @@ class GoogleDrive(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator:
if failure is not None:
- logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
+ logging.warning("Google Drive connector failure: %s",
+ getattr(failure, "failure_message", failure))
continue
if document is not None:
pending_docs.append(document)
@@ -646,10 +661,10 @@ class WebDAV(SyncBase):
remote_path=self.conf.get("remote_path", "/")
)
self.connector.load_credentials(self.conf["credentials"])
-
+
logging.info(f"Task info: reindex={task['reindex']}, poll_range_start={task['poll_range_start']}")
-
- if task["reindex"]=="1" or not task["poll_range_start"]:
+
+ if task["reindex"] == "1" or not task["poll_range_start"]:
logging.info("Using load_from_state (full sync)")
document_batch_generator = self.connector.load_from_state()
begin_info = "totally"
@@ -659,14 +674,15 @@ class WebDAV(SyncBase):
logging.info(f"Polling WebDAV from {task['poll_range_start']} (ts: {start_ts}) to now (ts: {end_ts})")
document_batch_generator = self.connector.poll_source(start_ts, end_ts)
begin_info = "from {}".format(task["poll_range_start"])
-
+
logging.info("Connect to WebDAV: {}(path: {}) {}".format(
self.conf["base_url"],
self.conf.get("remote_path", "/"),
begin_info
))
return document_batch_generator
-
+
+
class Moodle(SyncBase):
SOURCE_NAME: str = FileSource.MOODLE
@@ -675,7 +691,7 @@ class Moodle(SyncBase):
moodle_url=self.conf["moodle_url"],
batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE)
)
-
+
self.connector.load_credentials(self.conf["credentials"])
# Determine the time range for synchronization based on reindex or poll_range_start
@@ -689,7 +705,7 @@ class Moodle(SyncBase):
begin_info = "totally"
else:
document_generator = self.connector.poll_source(
- poll_start.timestamp(),
+ poll_start.timestamp(),
datetime.now(timezone.utc).timestamp()
)
begin_info = "from {}".format(poll_start)
@@ -718,7 +734,7 @@ class BOX(SyncBase):
token = AccessToken(
access_token=credential['access_token'],
refresh_token=credential['refresh_token'],
- )
+ )
auth.token_storage.store(token)
self.connector.load_credentials(auth)
@@ -739,6 +755,7 @@ class BOX(SyncBase):
logging.info("Connect to Box: folder_id({}) {}".format(self.conf["folder_id"], begin_info))
return document_generator
+
class Airtable(SyncBase):
SOURCE_NAME: str = FileSource.AIRTABLE
@@ -784,6 +801,7 @@ class Airtable(SyncBase):
return document_generator
+
func_factory = {
FileSource.S3: S3,
FileSource.R2: R2,
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 9a4b70364..ab75cd0c1 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -92,7 +92,7 @@ FACTORY = {
}
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
- "dataflow" : PipelineTaskType.PARSE,
+ "dataflow": PipelineTaskType.PARSE,
"raptor": PipelineTaskType.RAPTOR,
"graphrag": PipelineTaskType.GRAPH_RAG,
"mindmap": PipelineTaskType.MINDMAP,
@@ -221,7 +221,7 @@ async def get_storage_binary(bucket, name):
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
-@timeout(60*80, 1)
+@timeout(60 * 80, 1)
async def build_chunks(task, progress_callback):
if task["size"] > settings.DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
@@ -283,7 +283,8 @@ async def build_chunks(task, progress_callback):
try:
d = copy.deepcopy(document)
d.update(chunk)
- d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+ d["id"] = xxhash.xxh64(
+ (chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"):
@@ -328,9 +329,11 @@ async def build_chunks(task, progress_callback):
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
+
tasks = []
for d in docs:
- tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
+ tasks.append(
+ asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
@@ -355,9 +358,11 @@ async def build_chunks(task, progress_callback):
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
+
tasks = []
for d in docs:
- tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
+ tasks.append(
+ asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
@@ -374,15 +379,18 @@ async def build_chunks(task, progress_callback):
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
async def gen_metadata_task(chat_mdl, d):
- cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", task["parser_config"]["metadata"])
+ cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
+ task["parser_config"]["metadata"])
if not cached:
async with chat_limiter:
cached = await gen_metadata(chat_mdl,
metadata_schema(task["parser_config"]["metadata"]),
d["content_with_weight"])
- set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", task["parser_config"]["metadata"])
+ set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
+ task["parser_config"]["metadata"])
if cached:
d["metadata_obj"] = cached
+
tasks = []
for d in docs:
tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
@@ -430,7 +438,8 @@ async def build_chunks(task, progress_callback):
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return None
- if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+ if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
+ d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
docs_to_tag.append(d)
@@ -438,7 +447,7 @@ async def build_chunks(task, progress_callback):
async def doc_content_tagging(chat_mdl, d, topn_tags):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
- picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
+ picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
async with chat_limiter:
@@ -454,6 +463,7 @@ async def build_chunks(task, progress_callback):
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
+
tasks = []
for d in docs_to_tag:
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
@@ -473,21 +483,22 @@ async def build_chunks(task, progress_callback):
def build_TOC(task, docs, progress_callback):
progress_callback(msg="Start to generate table of content ...")
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
- docs = sorted(docs, key=lambda d:(
+ docs = sorted(docs, key=lambda d: (
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
- toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
- logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
+ toc: list[dict] = asyncio.run(
+ run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
+ logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
- if ii == len(toc) -1:
+ if ii == len(toc) - 1:
break
- for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
+ for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
@@ -499,7 +510,8 @@ def build_TOC(task, docs, progress_callback):
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
- d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+ d["id"] = xxhash.xxh64(
+ (d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
@@ -532,12 +544,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
@timeout(60)
def batch_encode(txts):
nonlocal mdl
- return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
+ return mdl.encode([truncate(c, mdl.max_length - 10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
- vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
+ vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0:
cnts_ = vts
else:
@@ -545,7 +557,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count += c
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
cnts = cnts_
- filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
+ filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
if not filename_embd_weight:
filename_embd_weight = 0.1
title_w = float(filename_embd_weight)
@@ -588,7 +600,8 @@ async def run_dataflow(task: dict):
return
if not chunks:
- PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+ PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+ task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
@@ -610,25 +623,27 @@ async def run_dataflow(task: dict):
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
+
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
+
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
- delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1)
+ delta = 0.20 / (len(texts) // settings.EMBEDDING_BATCH_SIZE + 1)
prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
- vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
+ vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
- if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1:
- set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}")
+ if i % (len(texts) // settings.EMBEDDING_BATCH_SIZE / 100 + 1) == 1:
+ set_progress(task_id, prog=prog, msg=f"{i + 1} / {len(texts) // settings.EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
@@ -636,10 +651,10 @@ async def run_dataflow(task: dict):
ck["q_%d_vec" % len(v)] = v
except Exception as e:
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
- PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+ PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+ task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
-
metadata = {}
for ck in chunks:
ck["doc_id"] = doc_id
@@ -686,15 +701,19 @@ async def run_dataflow(task: dict):
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
- PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+ PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+ task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
- DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
- logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
- PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+ DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
+ task_time_cost)
+ logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
+ task_time_cost))
+ PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
+ dsl=str(pipeline))
@timeout(3600)
@@ -702,7 +721,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
raptor_config = kb_parser_config.get("raptor", {})
- vctr_nm = "q_%d_vec"%vector_size
+ vctr_nm = "q_%d_vec" % vector_size
res = []
tk_count = 0
@@ -747,17 +766,17 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for x, doc_id in enumerate(doc_ids):
chunks = []
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
- fields=["content_with_weight", vctr_nm],
- sort_by_position=True):
+ fields=["content_with_weight", vctr_nm],
+ sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, doc_id)
- callback(prog=(x+1.)/len(doc_ids))
+ callback(prog=(x + 1.) / len(doc_ids))
else:
chunks = []
for doc_id in doc_ids:
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
- fields=["content_with_weight", vctr_nm],
- sort_by_position=True):
+ fields=["content_with_weight", vctr_nm],
+ sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, fake_doc_id)
@@ -792,19 +811,22 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mom_ck["available_int"] = 0
flds = list(mom_ck.keys())
for fld in flds:
- if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int"]:
+ if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
+ "position_int"]:
del mom_ck[fld]
mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
- await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
+ await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
+ search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
- doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
+ doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
+ search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@@ -821,7 +843,8 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
- doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
+ doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
+ search.index_name(task_tenant_id), task_dataset_id, )
tasks = []
for chunk_id in chunk_ids:
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
@@ -838,7 +861,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
return True
-@timeout(60*60*3, 1)
+@timeout(60 * 60 * 3, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
@@ -914,7 +937,7 @@ async def do_handle_task(task):
},
}
)
- if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
+ if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
@@ -943,7 +966,7 @@ async def do_handle_task(task):
doc_ids=task.get("doc_ids", []),
)
if fake_doc_ids := task.get("doc_ids", []):
- task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
+ task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
# Either using graphrag or Standard chunking methods
elif task_type == "graphrag":
ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
@@ -968,11 +991,10 @@ async def do_handle_task(task):
}
}
)
- if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
+ if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return
-
graphrag_conf = kb_parser_config.get("graphrag", {})
start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
@@ -1030,7 +1052,7 @@ async def do_handle_task(task):
return True
e = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
return bool(e)
-
+
try:
if not await _maybe_insert_es(chunks):
return
@@ -1084,8 +1106,8 @@ async def do_handle_task(task):
f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled."
)
-async def handle_task():
+async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
@@ -1093,7 +1115,8 @@ async def handle_task():
return
task_type = task["task_type"]
- pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
+ pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
+ PipelineTaskType.PARSE) or PipelineTaskType.PARSE
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
@@ -1119,7 +1142,9 @@ async def handle_task():
if task_type in ["graphrag", "raptor", "mindmap"]:
task_document_ids = task["doc_ids"]
if not task.get("dataflow_id", ""):
- PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
+ PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
+ task_type=pipeline_task_type,
+ fake_document_ids=task_document_ids)
redis_msg.ack()
@@ -1249,6 +1274,7 @@ async def main():
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!")
+
if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)
diff --git a/rag/utils/azure_spn_conn.py b/rag/utils/azure_spn_conn.py
index 005d3ba6b..12bcc6410 100644
--- a/rag/utils/azure_spn_conn.py
+++ b/rag/utils/azure_spn_conn.py
@@ -42,8 +42,10 @@ class RAGFlowAzureSpnBlob:
pass
try:
- credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id, client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
- self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name, credential=credentials)
+ credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id,
+ client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
+ self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name,
+ credential=credentials)
except Exception:
logging.exception("Fail to connect %s" % self.account_url)
@@ -104,4 +106,4 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
- return None
\ No newline at end of file
+ return None
diff --git a/rag/utils/base64_image.py b/rag/utils/base64_image.py
index 935979710..ecdf24387 100644
--- a/rag/utils/base64_image.py
+++ b/rag/utils/base64_image.py
@@ -25,7 +25,8 @@ from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64)
-async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
+
+async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str = "imagetemps"):
import logging
from io import BytesIO
from rag.svr.task_executor import minio_limiter
@@ -74,7 +75,7 @@ async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="
del d["image"]
-def id2image(image_id:str|None, storage_get_func: partial):
+def id2image(image_id: str | None, storage_get_func: partial):
if not image_id:
return
arr = image_id.split("-")
diff --git a/rag/utils/encrypted_storage.py b/rag/utils/encrypted_storage.py
index 19e199f4e..e5ac9cf97 100644
--- a/rag/utils/encrypted_storage.py
+++ b/rag/utils/encrypted_storage.py
@@ -16,11 +16,13 @@
import logging
from common.crypto_utils import CryptoUtil
+
+
# from common.decorator import singleton
class EncryptedStorageWrapper:
"""Encrypted storage wrapper that wraps existing storage implementations to provide transparent encryption"""
-
+
def __init__(self, storage_impl, algorithm="aes-256-cbc", key=None, iv=None):
"""
Initialize encrypted storage wrapper
@@ -34,16 +36,16 @@ class EncryptedStorageWrapper:
self.storage_impl = storage_impl
self.crypto = CryptoUtil(algorithm=algorithm, key=key, iv=iv)
self.encryption_enabled = True
-
+
# Check if storage implementation has required methods
# todo: Consider abstracting a storage base class to ensure these methods exist
required_methods = ["put", "get", "rm", "obj_exist", "health"]
for method in required_methods:
if not hasattr(storage_impl, method):
raise AttributeError(f"Storage implementation missing required method: {method}")
-
+
logging.info(f"EncryptedStorageWrapper initialized with algorithm: {algorithm}")
-
+
def put(self, bucket, fnm, binary, tenant_id=None):
"""
Encrypt and store data
@@ -59,15 +61,15 @@ class EncryptedStorageWrapper:
"""
if not self.encryption_enabled:
return self.storage_impl.put(bucket, fnm, binary, tenant_id)
-
+
try:
encrypted_binary = self.crypto.encrypt(binary)
-
+
return self.storage_impl.put(bucket, fnm, encrypted_binary, tenant_id)
except Exception as e:
logging.exception(f"Failed to encrypt and store data: {bucket}/{fnm}, error: {str(e)}")
raise
-
+
def get(self, bucket, fnm, tenant_id=None):
"""
Retrieve and decrypt data
@@ -83,21 +85,21 @@ class EncryptedStorageWrapper:
try:
# Get encrypted data
encrypted_binary = self.storage_impl.get(bucket, fnm, tenant_id)
-
+
if encrypted_binary is None:
return None
-
+
if not self.encryption_enabled:
return encrypted_binary
-
+
# Decrypt data
decrypted_binary = self.crypto.decrypt(encrypted_binary)
return decrypted_binary
-
+
except Exception as e:
logging.exception(f"Failed to get and decrypt data: {bucket}/{fnm}, error: {str(e)}")
raise
-
+
def rm(self, bucket, fnm, tenant_id=None):
"""
Delete data (same as original storage implementation, no decryption needed)
@@ -111,7 +113,7 @@ class EncryptedStorageWrapper:
Deletion result
"""
return self.storage_impl.rm(bucket, fnm, tenant_id)
-
+
def obj_exist(self, bucket, fnm, tenant_id=None):
"""
Check if object exists (same as original storage implementation, no decryption needed)
@@ -125,7 +127,7 @@ class EncryptedStorageWrapper:
Whether the object exists
"""
return self.storage_impl.obj_exist(bucket, fnm, tenant_id)
-
+
def health(self):
"""
Health check (uses the original storage implementation's method)
@@ -134,7 +136,7 @@ class EncryptedStorageWrapper:
Health check result
"""
return self.storage_impl.health()
-
+
def bucket_exists(self, bucket):
"""
Check if bucket exists (if the original storage implementation has this method)
@@ -148,7 +150,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "bucket_exists"):
return self.storage_impl.bucket_exists(bucket)
return False
-
+
def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
"""
Get presigned URL (if the original storage implementation has this method)
@@ -165,7 +167,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "get_presigned_url"):
return self.storage_impl.get_presigned_url(bucket, fnm, expires, tenant_id)
return None
-
+
def scan(self, bucket, fnm, tenant_id=None):
"""
Scan objects (if the original storage implementation has this method)
@@ -181,7 +183,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "scan"):
return self.storage_impl.scan(bucket, fnm, tenant_id)
return None
-
+
def copy(self, src_bucket, src_path, dest_bucket, dest_path):
"""
Copy object (if the original storage implementation has this method)
@@ -198,7 +200,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "copy"):
return self.storage_impl.copy(src_bucket, src_path, dest_bucket, dest_path)
return False
-
+
def move(self, src_bucket, src_path, dest_bucket, dest_path):
"""
Move object (if the original storage implementation has this method)
@@ -215,7 +217,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "move"):
return self.storage_impl.move(src_bucket, src_path, dest_bucket, dest_path)
return False
-
+
def remove_bucket(self, bucket):
"""
Remove bucket (if the original storage implementation has this method)
@@ -229,17 +231,18 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "remove_bucket"):
return self.storage_impl.remove_bucket(bucket)
return False
-
+
def enable_encryption(self):
"""Enable encryption"""
self.encryption_enabled = True
logging.info("Encryption enabled")
-
+
def disable_encryption(self):
"""Disable encryption"""
self.encryption_enabled = False
logging.info("Encryption disabled")
+
# Create singleton wrapper function
def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_enabled=True):
"""
@@ -255,12 +258,12 @@ def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_
Encrypted storage wrapper instance
"""
wrapper = EncryptedStorageWrapper(storage_impl, algorithm=algorithm, key=key)
-
+
wrapper.encryption_enabled = encryption_enabled
-
+
if encryption_enabled:
logging.info("Encryption enabled in storage wrapper")
else:
logging.info("Encryption disabled in storage wrapper")
-
+
return wrapper
diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py
index c991ac2a8..1d7b02e36 100644
--- a/rag/utils/es_conn.py
+++ b/rag/utils/es_conn.py
@@ -32,7 +32,6 @@ ATTEMPT_TIME = 2
@singleton
class ESConnection(ESConnectionBase):
-
"""
CRUD operations
"""
@@ -82,8 +81,9 @@ class ESConnection(ESConnectionBase):
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
- assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
- MatchDenseExpr) and isinstance(
+ assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
+ match_expressions[1],
+ MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
@@ -93,9 +93,9 @@ class ESConnection(ESConnectionBase):
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bool_query.must.append(Q("query_string", fields=m.fields,
- type="best_fields", query=m.matching_text,
- minimum_should_match=minimum_should_match,
- boost=1))
+ type="best_fields", query=m.matching_text,
+ minimum_should_match=minimum_should_match,
+ boost=1))
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
@@ -146,7 +146,7 @@ class ESConnection(ESConnectionBase):
for i in range(ATTEMPT_TIME):
try:
- #print(json.dumps(q, ensure_ascii=False))
+ # print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=index_names,
body=q,
timeout="600s",
@@ -220,13 +220,15 @@ class ESConnection(ESConnectionBase):
try:
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
- self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
+ self.logger.exception(
+ f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=index_name, id=chunk_id, doc=doc)
return True
except Exception as e:
self.logger.exception(
- f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
+ f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(
+ e))
break
return False
diff --git a/rag/utils/file_utils.py b/rag/utils/file_utils.py
index 18cdc35e1..8d19079b7 100644
--- a/rag/utils/file_utils.py
+++ b/rag/utils/file_utils.py
@@ -25,18 +25,23 @@ import PyPDF2
from docx import Document
import olefile
+
def _is_zip(h: bytes) -> bool:
return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08")
+
def _is_pdf(h: bytes) -> bool:
return h.startswith(b"%PDF-")
+
def _is_ole(h: bytes) -> bool:
return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1")
+
def _sha10(b: bytes) -> str:
return hashlib.sha256(b).hexdigest()[:10]
+
def _guess_ext(b: bytes) -> str:
h = b[:8]
if _is_zip(h):
@@ -58,13 +63,14 @@ def _guess_ext(b: bytes) -> str:
return ".doc"
return ".bin"
+
# Try to extract the real embedded payload from OLE's Ole10Native
def _extract_ole10native_payload(data: bytes) -> bytes:
try:
pos = 0
if len(data) < 4:
return data
- _ = int.from_bytes(data[pos:pos+4], "little")
+ _ = int.from_bytes(data[pos:pos + 4], "little")
pos += 4
# filename/src/tmp (NUL-terminated ANSI)
for _ in range(3):
@@ -74,14 +80,15 @@ def _extract_ole10native_payload(data: bytes) -> bytes:
pos += 4
if pos + 4 > len(data):
return data
- size = int.from_bytes(data[pos:pos+4], "little")
+ size = int.from_bytes(data[pos:pos + 4], "little")
pos += 4
if pos + size <= len(data):
- return data[pos:pos+size]
+ return data[pos:pos + size]
except Exception:
pass
return data
+
def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]:
"""
Only extract the 'first layer' of embedding, returning raw (filename, bytes).
@@ -163,7 +170,7 @@ def extract_links_from_docx(docx_bytes: bytes):
# Each relationship may represent a hyperlink, image, footer, etc.
for rel in document.part.rels.values():
if rel.reltype == (
- "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
+ "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
):
links.add(rel.target_ref)
@@ -198,6 +205,8 @@ def extract_links_from_pdf(pdf_bytes: bytes):
_GLOBAL_SESSION: Optional[requests.Session] = None
+
+
def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
"""Get or create a global reusable session."""
global _GLOBAL_SESSION
@@ -216,10 +225,10 @@ def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
def extract_html(
- url: str,
- timeout: float = 60.0,
- headers: Optional[Dict[str, str]] = None,
- max_retries: int = 2,
+ url: str,
+ timeout: float = 60.0,
+ headers: Optional[Dict[str, str]] = None,
+ max_retries: int = 2,
) -> Tuple[Optional[bytes], Dict[str, str]]:
"""
Extract the full HTML page as raw bytes from a given URL.
@@ -260,4 +269,4 @@ def extract_html(
metadata["error"] = f"Request failed: {e}"
continue
- return None, metadata
\ No newline at end of file
+ return None, metadata
diff --git a/rag/utils/gcs_conn.py b/rag/utils/gcs_conn.py
index 5268cea42..53aa3d4e5 100644
--- a/rag/utils/gcs_conn.py
+++ b/rag/utils/gcs_conn.py
@@ -204,4 +204,4 @@ class RAGFlowGCS:
return False
except Exception:
logging.exception(f"Fail to move {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
- return False
\ No newline at end of file
+ return False
diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py
index 8805a754b..79f871e80 100644
--- a/rag/utils/infinity_conn.py
+++ b/rag/utils/infinity_conn.py
@@ -28,7 +28,6 @@ from common.doc_store.infinity_conn_base import InfinityConnectionBase
@singleton
class InfinityConnection(InfinityConnectionBase):
-
"""
Dataframe and fields convert
"""
@@ -83,24 +82,23 @@ class InfinityConnection(InfinityConnectionBase):
tokens[0] = field
return "^".join(tokens)
-
"""
CRUD operations
"""
def search(
- self,
- select_fields: list[str],
- highlight_fields: list[str],
- condition: dict,
- match_expressions: list[MatchExpr],
- order_by: OrderByExpr,
- offset: int,
- limit: int,
- index_names: str | list[str],
- knowledgebase_ids: list[str],
- agg_fields: list[str] | None = None,
- rank_feature: dict | None = None,
+ self,
+ select_fields: list[str],
+ highlight_fields: list[str],
+ condition: dict,
+ match_expressions: list[MatchExpr],
+ order_by: OrderByExpr,
+ offset: int,
+ limit: int,
+ index_names: str | list[str],
+ knowledgebase_ids: list[str],
+ agg_fields: list[str] | None = None,
+ rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
@@ -159,7 +157,8 @@ class InfinityConnection(InfinityConnectionBase):
if table_found:
break
if not table_found:
- self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
+ self.logger.error(
+ f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
return pd.DataFrame(), 0
for matchExpr in match_expressions:
@@ -280,7 +279,8 @@ class InfinityConnection(InfinityConnectionBase):
try:
table_instance = db_instance.get_table(table_name)
except Exception:
- self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
+ self.logger.warning(
+ f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
@@ -288,7 +288,9 @@ class InfinityConnection(InfinityConnectionBase):
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
- for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
+ for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd",
+ "question_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks",
+ "authors_sm_tks"]:
fields.add(field)
res_fields = self.get_fields(res, list(fields))
return res_fields.get(chunk_id, None)
@@ -379,7 +381,9 @@ class InfinityConnection(InfinityConnectionBase):
d[k] = "_".join(f"{num:08x}" for num in v)
else:
d[k] = v
- for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
+ for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
+ "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd",
+ "question_tks"]:
if k in d:
del d[k]
@@ -478,7 +482,8 @@ class InfinityConnection(InfinityConnectionBase):
del new_value[k]
else:
new_value[k] = v
- for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
+ for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
+ "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in new_value:
del new_value[k]
@@ -502,7 +507,8 @@ class InfinityConnection(InfinityConnectionBase):
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
for update_kv, ids in remove_opt.items():
k, v = json.loads(update_kv)
- table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
+ table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])),
+ {k: "###".join(v)})
table_instance.update(filter, new_value)
self.connPool.release_conn(inf_conn)
@@ -561,7 +567,7 @@ class InfinityConnection(InfinityConnectionBase):
def to_position_int(v):
if v:
arr = [int(hex_val, 16) for hex_val in v.split("_")]
- v = [arr[i : i + 5] for i in range(0, len(arr), 5)]
+ v = [arr[i: i + 5] for i in range(0, len(arr), 5)]
else:
v = []
return v
diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py
index a81fb38ab..2c7b35ff6 100644
--- a/rag/utils/minio_conn.py
+++ b/rag/utils/minio_conn.py
@@ -46,6 +46,7 @@ class RAGFlowMinio:
# pass original identifier forward for use by other decorators
kwargs['_orig_bucket'] = original_bucket
return method(self, actual_bucket, *args, **kwargs)
+
return wrapper
@staticmethod
@@ -71,6 +72,7 @@ class RAGFlowMinio:
fnm = f"{orig_bucket}/{fnm}"
return method(self, bucket, fnm, *args, **kwargs)
+
return wrapper
def __open__(self):
diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py
index 0786e9140..d43f8bb75 100644
--- a/rag/utils/ob_conn.py
+++ b/rag/utils/ob_conn.py
@@ -37,7 +37,8 @@ from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.float_utils import get_float
-from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
+from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
+ MatchDenseExpr
from rag.nlp import rag_tokenizer
ATTEMPT_TIME = 2
@@ -719,19 +720,19 @@ class OBConnection(DocStoreConnection):
"""
def search(
- self,
- selectFields: list[str],
- highlightFields: list[str],
- condition: dict,
- matchExprs: list[MatchExpr],
- orderBy: OrderByExpr,
- offset: int,
- limit: int,
- indexNames: str | list[str],
- knowledgebaseIds: list[str],
- aggFields: list[str] = [],
- rank_feature: dict | None = None,
- **kwargs,
+ self,
+ selectFields: list[str],
+ highlightFields: list[str],
+ condition: dict,
+ matchExprs: list[MatchExpr],
+ orderBy: OrderByExpr,
+ offset: int,
+ limit: int,
+ indexNames: str | list[str],
+ knowledgebaseIds: list[str],
+ aggFields: list[str] = [],
+ rank_feature: dict | None = None,
+ **kwargs,
):
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
@@ -1546,7 +1547,7 @@ class OBConnection(DocStoreConnection):
flags=re.IGNORECASE | re.MULTILINE,
)
if len(re.findall(r'', highlighted_txt)) > 0 or len(
- re.findall(r'\s*', highlighted_txt)) > 0:
+ re.findall(r'\s*', highlighted_txt)) > 0:
return highlighted_txt
else:
return None
@@ -1565,9 +1566,9 @@ class OBConnection(DocStoreConnection):
if token_pos != -1:
if token in keywords:
highlighted_txt = (
- highlighted_txt[:token_pos] +
- f'{token}' +
- highlighted_txt[token_pos + len(token):]
+ highlighted_txt[:token_pos] +
+ f'{token}' +
+ highlighted_txt[token_pos + len(token):]
)
last_pos = token_pos
return re.sub(r'', '', highlighted_txt)
diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py
index b188346c7..936081384 100644
--- a/rag/utils/opendal_conn.py
+++ b/rag/utils/opendal_conn.py
@@ -6,7 +6,6 @@ from urllib.parse import quote_plus
from common.config_utils import get_base_config
from common.decorator import singleton
-
CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS `{}` (
`key` VARCHAR(255) PRIMARY KEY,
@@ -36,7 +35,8 @@ def get_opendal_config():
"table": opendal_config.get("config", {}).get("oss_table", "opendal_storage"),
"max_allowed_packet": str(max_packet)
}
- kwargs["connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
+ kwargs[
+ "connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
else:
scheme = opendal_config.get("scheme")
config_data = opendal_config.get("config", {})
@@ -61,7 +61,7 @@ def get_opendal_config():
del kwargs["password"]
if "connection_string" in kwargs:
del kwargs["connection_string"]
- return kwargs
+ return kwargs
except Exception as e:
logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e))
raise
@@ -99,7 +99,6 @@ class OpenDALStorage:
def obj_exist(self, bucket, fnm, tenant_id=None):
return self._operator.exists(f"{bucket}/{fnm}")
-
def init_db_config(self):
try:
conn = pymysql.connect(
diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py
index 2e828be6e..67e7364fe 100644
--- a/rag/utils/opensearch_conn.py
+++ b/rag/utils/opensearch_conn.py
@@ -26,7 +26,8 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
-from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
+from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
+ FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
@@ -189,7 +190,7 @@ class OSConnection(DocStoreConnection):
minimum_should_match=minimum_should_match,
boost=1))
bqry.boost = 1.0 - vector_similarity_weight
-
+
# Elasticsearch has the encapsulation of KNN_search in python sdk
# while the Python SDK for OpenSearch does not provide encapsulation for KNN_search,
# the following codes implement KNN_search in OpenSearch using DSL
@@ -216,7 +217,7 @@ class OSConnection(DocStoreConnection):
if bqry:
s = s.query(bqry)
for field in highlightFields:
- s = s.highlight(field,force_source=True,no_match_size=30,require_field_match=False)
+ s = s.highlight(field, force_source=True, no_match_size=30, require_field_match=False)
if orderBy:
orders = list()
@@ -239,10 +240,10 @@ class OSConnection(DocStoreConnection):
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"OSConnection.search {str(indexNames)} query: " + json.dumps(q))
-
+
if use_knn:
del q["query"]
- q["query"] = {"knn" : knn_query}
+ q["query"] = {"knn": knn_query}
for i in range(ATTEMPT_TIME):
try:
@@ -328,7 +329,7 @@ class OSConnection(DocStoreConnection):
chunkId = condition["id"]
for i in range(ATTEMPT_TIME):
try:
- self.os.update(index=indexName, id=chunkId, body={"doc":doc})
+ self.os.update(index=indexName, id=chunkId, body={"doc": doc})
return True
except Exception as e:
logger.exception(
@@ -435,7 +436,7 @@ class OSConnection(DocStoreConnection):
logger.debug("OSConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
- #print(Search().query(qry).to_dict(), flush=True)
+ # print(Search().query(qry).to_dict(), flush=True)
res = self.os.delete_by_query(
index=indexName,
body=Search().query(qry).to_dict(),
diff --git a/rag/utils/oss_conn.py b/rag/utils/oss_conn.py
index b0114f668..60f7e6e96 100644
--- a/rag/utils/oss_conn.py
+++ b/rag/utils/oss_conn.py
@@ -42,14 +42,16 @@ class RAGFlowOSS:
# If there is a default bucket, use the default bucket
actual_bucket = self.bucket if self.bucket else bucket
return method(self, actual_bucket, *args, **kwargs)
+
return wrapper
-
+
@staticmethod
def use_prefix_path(method):
def wrapper(self, bucket, fnm, *args, **kwargs):
# If the prefix path is set, use the prefix path
fnm = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm
return method(self, bucket, fnm, *args, **kwargs)
+
return wrapper
def __open__(self):
@@ -171,4 +173,3 @@ class RAGFlowOSS:
self.__open__()
time.sleep(1)
return None
-
diff --git a/rag/utils/raptor_utils.py b/rag/utils/raptor_utils.py
index c48e0999b..dd6f75dd9 100644
--- a/rag/utils/raptor_utils.py
+++ b/rag/utils/raptor_utils.py
@@ -21,7 +21,6 @@ Utility functions for Raptor processing decisions.
import logging
from typing import Optional
-
# File extensions for structured data types
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
CSV_EXTENSIONS = {".csv", ".tsv"}
@@ -40,12 +39,12 @@ def is_structured_file_type(file_type: Optional[str]) -> bool:
"""
if not file_type:
return False
-
+
# Normalize to lowercase and ensure leading dot
file_type = file_type.lower()
if not file_type.startswith("."):
file_type = f".{file_type}"
-
+
return file_type in STRUCTURED_EXTENSIONS
@@ -61,23 +60,23 @@ def is_tabular_pdf(parser_id: str = "", parser_config: Optional[dict] = None) ->
True if PDF is being parsed as tabular data
"""
parser_config = parser_config or {}
-
+
# If using table parser, it's tabular
if parser_id and parser_id.lower() == "table":
return True
-
+
# Check if html4excel is enabled (Excel-like table parsing)
if parser_config.get("html4excel", False):
return True
-
+
return False
def should_skip_raptor(
- file_type: Optional[str] = None,
- parser_id: str = "",
- parser_config: Optional[dict] = None,
- raptor_config: Optional[dict] = None
+ file_type: Optional[str] = None,
+ parser_id: str = "",
+ parser_config: Optional[dict] = None,
+ raptor_config: Optional[dict] = None
) -> bool:
"""
Determine if Raptor should be skipped for a given document.
@@ -97,30 +96,30 @@ def should_skip_raptor(
"""
parser_config = parser_config or {}
raptor_config = raptor_config or {}
-
+
# Check if auto-disable is explicitly disabled in config
if raptor_config.get("auto_disable_for_structured_data", True) is False:
logging.info("Raptor auto-disable is turned off via configuration")
return False
-
+
# Check for Excel/CSV files
if is_structured_file_type(file_type):
logging.info(f"Skipping Raptor for structured file type: {file_type}")
return True
-
+
# Check for tabular PDFs
if file_type and file_type.lower() in [".pdf", "pdf"]:
if is_tabular_pdf(parser_id, parser_config):
logging.info(f"Skipping Raptor for tabular PDF (parser_id={parser_id})")
return True
-
+
return False
def get_skip_reason(
- file_type: Optional[str] = None,
- parser_id: str = "",
- parser_config: Optional[dict] = None
+ file_type: Optional[str] = None,
+ parser_id: str = "",
+ parser_config: Optional[dict] = None
) -> str:
"""
Get a human-readable reason why Raptor was skipped.
@@ -134,12 +133,12 @@ def get_skip_reason(
Reason string, or empty string if Raptor should not be skipped
"""
parser_config = parser_config or {}
-
+
if is_structured_file_type(file_type):
return f"Structured data file ({file_type}) - Raptor auto-disabled"
-
+
if file_type and file_type.lower() in [".pdf", "pdf"]:
if is_tabular_pdf(parser_id, parser_config):
return f"Tabular PDF (parser={parser_id}) - Raptor auto-disabled"
-
+
return ""
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index 4bea635bb..fd5903ce1 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -33,6 +33,7 @@ except Exception:
except Exception:
REDIS = {}
+
class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer
@@ -278,7 +279,8 @@ class RedisDB:
def decrby(self, key: str, decrement: int):
return self.REDIS.decrby(key, decrement)
- def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int:
+ def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default",
+ increment: int = 1, ensure_minimum: int | None = None) -> int:
redis_key = f"{key_prefix}:{namespace}"
try:
diff --git a/rag/utils/s3_conn.py b/rag/utils/s3_conn.py
index 11ac65cee..0e3ab4b43 100644
--- a/rag/utils/s3_conn.py
+++ b/rag/utils/s3_conn.py
@@ -46,6 +46,7 @@ class RAGFlowS3:
# If there is a default bucket, use the default bucket
actual_bucket = self.bucket if self.bucket else bucket
return method(self, actual_bucket, *args, **kwargs)
+
return wrapper
@staticmethod
@@ -57,6 +58,7 @@ class RAGFlowS3:
if self.prefix_path:
fnm = f"{self.prefix_path}/{bucket}/{fnm}"
return method(self, bucket, fnm, *args, **kwargs)
+
return wrapper
def __open__(self):
@@ -81,16 +83,16 @@ class RAGFlowS3:
s3_params['region_name'] = self.region_name
if self.endpoint_url:
s3_params['endpoint_url'] = self.endpoint_url
-
+
# Configure signature_version and addressing_style through Config object
if self.signature_version:
config_kwargs['signature_version'] = self.signature_version
if self.addressing_style:
config_kwargs['s3'] = {'addressing_style': self.addressing_style}
-
+
if config_kwargs:
s3_params['config'] = Config(**config_kwargs)
-
+
self.conn = [boto3.client('s3', **s3_params)]
except Exception:
logging.exception(f"Fail to connect at region {self.region_name} or endpoint {self.endpoint_url}")
@@ -184,9 +186,9 @@ class RAGFlowS3:
for _ in range(10):
try:
r = self.conn[0].generate_presigned_url('get_object',
- Params={'Bucket': bucket,
- 'Key': fnm},
- ExpiresIn=expires)
+ Params={'Bucket': bucket,
+ 'Key': fnm},
+ ExpiresIn=expires)
return r
except Exception:
diff --git a/rag/utils/tavily_conn.py b/rag/utils/tavily_conn.py
index d57271716..1b391fb1b 100644
--- a/rag/utils/tavily_conn.py
+++ b/rag/utils/tavily_conn.py
@@ -30,7 +30,8 @@ class Tavily:
search_depth="advanced",
max_results=6
)
- return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res in response["results"]]
+ return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res
+ in response["results"]]
except Exception as e:
logging.exception(e)
@@ -64,5 +65,5 @@ class Tavily:
"count": 1,
"url": r["url"]
})
- logging.info("[Tavily]R: "+r["content"][:128]+"...")
- return {"chunks": chunks, "doc_aggs": aggs}
\ No newline at end of file
+ logging.info("[Tavily]R: " + r["content"][:128] + "...")
+ return {"chunks": chunks, "doc_aggs": aggs}