Fix IDE warnings (#12281)

### What problem does this PR solve?

As the title says: clean up IDE warnings across the parser, NLP, and retrieval modules. The changes are almost entirely mechanical style fixes (spacing around keyword arguments and operators, spaces after commas, inline-comment spacing, long-line wrapping, import wrapping) plus a couple of small logging/format corrections. A representative before/after pattern is sketched below.
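
For reference, a minimal sketch of the dominant pattern in the diff (illustrative toy code under assumed names; `naive_merge_demo` is a hypothetical stand-in, not an actual function from this repository):

```python
# Illustrative only: reproduces the kind of spacing warnings this PR fixes
# (keyword-argument spacing, spaces after commas, operator spacing, inline comments).

def naive_merge_demo(sections, chunk_token_num=128, delimiter="\n!?。;!?"):
    """Hypothetical stand-in with a signature similar to rag.nlp.naive_merge."""
    return [s for s in sections if s][:chunk_token_num]


# Before (flagged by the IDE):
#   chunks = naive_merge_demo(sections,chunk_token_num = 256,delimiter = "\n!?。;!?")#merge

# After (the style applied throughout this PR):
sections = ["first section", "second section"]
chunks = naive_merge_demo(sections, chunk_token_num=256, delimiter="\n!?。;!?")  # merge sections
print(chunks)
```

The larger hunks apply the same idea to wrapped call sites and imports without changing behavior.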

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Commit: 01f0ced1e6 (parent 647fb115a0)
Author: Jin Hai
Committed via GitHub on 2025-12-29 12:01:18 +08:00
43 changed files with 817 additions and 637 deletions

View File

@ -34,7 +34,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
if not ext:
raise RuntimeError("No extension detected.")
if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma",
".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
raise RuntimeError(f"Extension {ext} is not supported yet.")
tmp_path = ""

View File

@ -22,7 +22,7 @@ from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
from rag.nlp import bullets_category, is_english,remove_contents_table, \
from rag.nlp import bullets_category, is_english, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks, attach_media_context
from rag.nlp import rag_tokenizer
@ -91,9 +91,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
filename, binary=binary, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
tbls = vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_docx_wrapper(sections=sections, tbls=tbls, callback=callback, **kwargs)
# tbls = [((None, lns), None) for lns in tbls]
sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
sections = [(item[0], item[1] if item[1] is not None else "") for item in sections if
not isinstance(item[1], Image.Image)]
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
@ -109,14 +110,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
sections, tables, pdf_parser = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
layout_recognizer = layout_recognizer,
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
**kwargs
)
@ -126,7 +127,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
@ -175,7 +176,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
for ck in hierarchical_merge(bull, sections, 5)]
else:
sections = [s.split("@") for s, _ in sections]
sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections ]
sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections]
chunks = naive_merge(
sections,
parser_config.get("chunk_token_num", 256),
@ -199,6 +200,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)

View File

@ -26,13 +26,13 @@ import io
def chunk(
filename,
binary=None,
from_page=0,
to_page=100000,
lang="Chinese",
callback=None,
**kwargs,
filename,
binary=None,
from_page=0,
to_page=100000,
lang="Chinese",
callback=None,
**kwargs,
):
"""
Only eml is supported
@ -93,7 +93,8 @@ def chunk(
_add_content(msg, msg.get_content_type())
sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
(line, "") for line in
HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
]
st = timer()
@ -126,7 +127,9 @@ def chunk(
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], callback=dummy)

View File

@ -29,8 +29,6 @@ from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
class Docx(DocxParser):
def __init__(self):
pass
@ -58,37 +56,36 @@ class Docx(DocxParser):
return [line for line in lines if line]
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document(
filename) if not binary else Document(BytesIO(binary))
pn = 0
lines = []
level_set = set()
bull = bullets_category([p.text for p in self.doc.paragraphs])
for p in self.doc.paragraphs:
if pn > to_page:
break
question_level, p_text = docx_question_level(p, bull)
if not p_text.strip("\n"):
self.doc = Document(
filename) if not binary else Document(BytesIO(binary))
pn = 0
lines = []
level_set = set()
bull = bullets_category([p.text for p in self.doc.paragraphs])
for p in self.doc.paragraphs:
if pn > to_page:
break
question_level, p_text = docx_question_level(p, bull)
if not p_text.strip("\n"):
continue
lines.append((question_level, p_text))
level_set.add(question_level)
for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
continue
lines.append((question_level, p_text))
level_set.add(question_level)
for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
sorted_levels = sorted(level_set)
sorted_levels = sorted(level_set)
h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
root = Node(level=0, depth=h2_level, texts=[])
root.build_tree(lines)
return [element for element in root.get_tree() if element]
root = Node(level=0, depth=h2_level, texts=[])
root.build_tree(lines)
return [element for element in root.get_tree() if element]
def __str__(self) -> str:
return f'''
@ -121,8 +118,7 @@ class Pdf(PdfParser):
start = timer()
self._layouts_rec(zoomin)
callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
logging.debug("layouts:".format(
))
logging.debug("layouts: {}".format((timer() - start)))
self._naive_vertical_merge()
callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
@ -154,7 +150,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
chunks = Docx()(filename, binary)
callback(0.7, "Finish parsing.")
return tokenize_chunks(chunks, doc, eng, None)
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer, parser_model_name = normalize_layout_recognizer(
parser_config.get("layout_recognize", "DeepDOC")
@ -168,14 +164,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
raw_sections, tables, pdf_parser = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
layout_recognizer = layout_recognizer,
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
**kwargs
)
@ -185,7 +181,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
for txt, poss in raw_sections:
sections.append(txt + poss)
@ -226,7 +222,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)")
# Remove 'Contents' part
remove_contents_table(sections, eng)
@ -234,7 +229,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
bull = bullets_category(sections)
res = tree_merge(bull, sections, 2)
if not res:
callback(0.99, "No chunk parsed out.")
@ -243,9 +237,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
# chunks = hierarchical_merge(bull, sections, 5)
# return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], callback=dummy)

View File

@ -20,15 +20,17 @@ import re
from common.constants import ParserType
from io import BytesIO
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level, attach_media_context
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, \
docx_question_level, attach_media_context
from common.token_utils import num_tokens_from_string
from deepdoc.parser import PdfParser, DocxParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper, vision_figure_parser_docx_wrapper
from docx import Document
from PIL import Image
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.MANUAL.value
@ -129,11 +131,11 @@ class Docx(DocxParser):
question_level, p_text = 0, ''
if from_page <= pn < to_page and p.text.strip():
question_level, p_text = docx_question_level(p)
if not question_level or question_level > 6: # not a question
if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{p_text}'
current_image = self.get_picture(self.doc, p)
last_image = self.concat_img(last_image, current_image)
else: # is a question
else: # is a question
if last_answer or last_image:
sum_question = '\n'.join(question_stack)
if sum_question:
@ -159,14 +161,14 @@ class Docx(DocxParser):
tbls = []
for tb in self.doc.tables:
html= "<table>"
html = "<table>"
for r in tb.rows:
html += "<tr>"
i = 0
while i < len(r.cells):
span = 1
c = r.cells[i]
for j in range(i+1, len(r.cells)):
for j in range(i + 1, len(r.cells)):
if c.text == r.cells[j].text:
span += 1
i = j
@ -211,16 +213,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
kwargs.pop("parse_method", None)
kwargs.pop("mineru_llm_name", None)
sections, tbls, pdf_parser = pdf_parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
layout_recognizer = layout_recognizer,
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
parse_method = "manual",
parse_method="manual",
**kwargs
)
@ -237,10 +239,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if isinstance(poss, str):
poss = pdf_parser.extract_positions(poss)
if poss:
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
pn = first[0]
if isinstance(pn, list) and pn:
pn = pn[0] # [pn] -> pn
pn = pn[0] # [pn] -> pn
poss[0] = (pn, *first[1:])
return (txt, layoutno, poss)
@ -289,7 +291,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0], -1,
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
def tag(pn, left, right, top, bottom):
if pn + left + right + top + bottom == 0:
@ -312,7 +314,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
tk_cnt = num_tokens_from_string(txt)
if sec_id > -1:
last_sid = sec_id
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))
@ -325,7 +327,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
docx_parser = Docx()
ti_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
tbls = vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_docx_wrapper(sections=ti_list, tbls=tbls, callback=callback, **kwargs)
res = tokenize_table(tbls, doc, eng)
for text, image in ti_list:
d = copy.deepcopy(doc)

View File

@ -31,16 +31,20 @@ from common.token_utils import num_tokens_from_string
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, \
PdfParser, TxtParser
from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_docx_wrapper, \
vision_figure_parser_pdf_wrapper
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from deepdoc.parser.docling_parser import DoclingParser
from deepdoc.parser.tcadp_parser import TCADPParser
from common.parser_config_utils import normalize_layout_recognizer
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, \
tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
**kwargs):
callback = callback
binary = binary
pdf_parser = pdf_cls() if pdf_cls else Pdf()
@ -58,17 +62,17 @@ def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
def by_mineru(
filename,
binary=None,
from_page=0,
to_page=100000,
lang="Chinese",
callback=None,
pdf_cls=None,
parse_method: str = "raw",
mineru_llm_name: str | None = None,
tenant_id: str | None = None,
**kwargs,
filename,
binary=None,
from_page=0,
to_page=100000,
lang="Chinese",
callback=None,
pdf_cls=None,
parse_method: str = "raw",
mineru_llm_name: str | None = None,
tenant_id: str | None = None,
**kwargs,
):
pdf_parser = None
if tenant_id:
@ -106,7 +110,8 @@ def by_mineru(
return None, None, None
def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
**kwargs):
pdf_parser = DoclingParser()
parse_method = kwargs.get("parse_method", "raw")
@ -125,7 +130,7 @@ def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
return sections, tables, pdf_parser
def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None, **kwargs):
tcadp_parser = TCADPParser()
if not tcadp_parser.check_installation():
@ -168,10 +173,10 @@ def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=No
PARSERS = {
"deepdoc": by_deepdoc,
"mineru": by_mineru,
"docling": by_docling,
"tcadp": by_tcadp,
"deepdoc": by_deepdoc,
"mineru": by_mineru,
"docling": by_docling,
"tcadp": by_tcadp,
"plaintext": by_plaintext, # default
}
@ -264,7 +269,7 @@ class Docx(DocxParser):
# Find the nearest heading paragraph in reverse order
nearest_title = None
for i in range(len(blocks)-1, -1, -1):
for i in range(len(blocks) - 1, -1, -1):
block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table
continue
@ -293,7 +298,7 @@ class Docx(DocxParser):
# Find all parent headings, allowing cross-level search
while current_level > 1:
found = False
for i in range(len(blocks)-1, -1, -1):
for i in range(len(blocks) - 1, -1, -1):
block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table
continue
@ -426,7 +431,8 @@ class Docx(DocxParser):
try:
if inline_images:
result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64))
result = mammoth.convert_to_html(docx_file,
convert_image=mammoth.images.img_element(_convert_image_to_base64))
else:
result = mammoth.convert_to_html(docx_file)
@ -621,6 +627,7 @@ class Markdown(MarkdownParser):
return sections, tbls, section_images
return sections, tbls
def load_from_xml_v2(baseURI, rels_item_xml):
"""
Return |_SerializedRelationships| instance loaded with the
@ -636,6 +643,7 @@ def load_from_xml_v2(baseURI, rels_item_xml):
srels._srels.append(_SerializedRelationship(baseURI, rel_elm))
return srels
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, excel, txt.
@ -651,7 +659,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True})
child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode(
'latin1').decode('utf-8')
cust_child_deli = re.findall(r"`([^`]+)`", child_deli)
child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))
if cust_child_deli:
@ -685,7 +694,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
# Recursively chunk each embedded file and collect results
for embed_filename, embed_bytes in embeds:
try:
sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False, **kwargs) or []
sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False,
**kwargs) or []
embed_res.extend(sub_res)
except Exception as e:
if callback:
@ -704,7 +714,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
except Exception as e:
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False,
**kwargs)
url_res.extend(sub_url_res)
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
@ -747,14 +758,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
callback(0.1, "Start to parse.")
sections, tables, pdf_parser = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
layout_recognizer = layout_recognizer,
mineru_llm_name = parser_model_name,
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
**kwargs
)
@ -846,9 +857,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
else:
section_images = [None] * len(sections)
section_images[idx] = combined_image
markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=[
((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
boosted_figures = markdown_vision_parser(callback=callback)
sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]), sections[idx][1])
sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]),
sections[idx][1])
else:
logging.warning("No visual model detected. Skipping figure parsing enhancement.")
@ -945,7 +958,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
has_images = merged_images and any(img is not None for img in merged_images)
if has_images:
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images, child_delimiters_pattern=child_deli))
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images,
child_delimiters_pattern=child_deli))
else:
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser, child_delimiters_pattern=child_deli))
else:
@ -955,10 +969,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if section_images:
chunks, images = naive_merge_with_images(sections, section_images,
int(parser_config.get(
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli))
int(parser_config.get(
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
res.extend(
tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli))
else:
chunks = naive_merge(
sections, int(parser_config.get(
@ -993,7 +1008,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -26,6 +26,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
@ -95,14 +96,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
sections, tbls, pdf_parser = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
layout_recognizer = layout_recognizer,
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
**kwargs
)
@ -112,9 +113,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
for (img, rows), poss in tbls:
if not rows:
continue
@ -172,7 +173,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -20,7 +20,8 @@ import re
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
from common.constants import ParserType
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks, attach_media_context
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, \
tokenize_chunks, attach_media_context
from deepdoc.parser import PdfParser
import numpy as np
from rag.app.naive import by_plaintext, PARSERS
@ -66,7 +67,7 @@ class Pdf(PdfParser):
# clean mess
if column_width < self.page_images[0].size[0] / zoomin / 2:
logging.debug("two_column................... {} {}".format(column_width,
self.page_images[0].size[0] / zoomin / 2))
self.page_images[0].size[0] / zoomin / 2))
self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
for b in self.boxes:
b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
@ -89,7 +90,7 @@ class Pdf(PdfParser):
title = ""
authors = []
i = 0
while i < min(32, len(self.boxes)-1):
while i < min(32, len(self.boxes) - 1):
b = self.boxes[i]
i += 1
if b.get("layoutno", "").find("title") >= 0:
@ -190,8 +191,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"tables": tables
}
tbls=paper["tables"]
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
tbls = paper["tables"]
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
paper["tables"] = tbls
else:
raise NotImplementedError("file type not supported yet(pdf supported)")
@ -329,6 +330,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], callback=dummy)

View File

@ -51,7 +51,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
}
)
cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
ans = asyncio.run(
cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
ans += "\n" + ans
tokenize(doc, ans, eng)

View File

@ -249,7 +249,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)

View File

@ -102,9 +102,9 @@ class Pdf(PdfParser):
self._text_merge()
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
tbls = self._extract_table_figure(True, zoomin, True, True)
#self._naive_vertical_merge()
# self._naive_vertical_merge()
# self._concat_downward()
#self._filter_forpages()
# self._filter_forpages()
logging.debug("layouts: {}".format(timer() - start))
sections = [b["text"] for b in self.boxes]
bull_x0_list = []
@ -114,12 +114,14 @@ class Pdf(PdfParser):
qai_list = []
last_q, last_a, last_tag = '', '', ''
last_index = -1
last_box = {'text':''}
last_box = {'text': ''}
last_bull = None
def sort_key(element):
tbls_pn = element[1][0][0]
tbls_top = element[1][0][3]
return tbls_pn, tbls_top
tbls.sort(key=sort_key)
tbl_index = 0
last_pn, last_bottom = 0, 0
@ -133,28 +135,32 @@ class Pdf(PdfParser):
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
if not has_bull: # No question bullet
if not last_q:
if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
tbl_index += 1
continue
else:
sum_tag = line_tag
sum_section = section
while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the middle of current answer
while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
and ((tbl_pn == line_pn and tbl_top <= line_top) or (
tbl_pn < line_pn)): # add image at the middle of current answer
sum_tag = f'{tbl_tag}{sum_tag}'
sum_section = f'{tbl_text}{sum_section}'
tbl_index += 1
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
tbl_index)
last_a = f'{last_a}{sum_section}'
last_tag = f'{last_tag}{sum_tag}'
else:
if last_q:
while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the end of last answer
while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
and ((tbl_pn == line_pn and tbl_top <= line_top) or (
tbl_pn < line_pn)): # add image at the end of last answer
last_tag = f'{last_tag}{tbl_tag}'
last_a = f'{last_a}{tbl_text}'
tbl_index += 1
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
tbl_index)
image, poss = self.crop(last_tag, need_position=True)
qai_list.append((last_q, last_a, image, poss))
last_q, last_a, last_tag = '', '', ''
@ -171,7 +177,7 @@ class Pdf(PdfParser):
def get_tbls_info(self, tbls, tbl_index):
if tbl_index >= len(tbls):
return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
tbl_pn = tbls[tbl_index][1][0][0]+1
tbl_pn = tbls[tbl_index][1][0][0] + 1
tbl_left = tbls[tbl_index][1][0][1]
tbl_right = tbls[tbl_index][1][0][2]
tbl_top = tbls[tbl_index][1][0][3]
@ -210,11 +216,11 @@ class Docx(DocxParser):
question_level, p_text = 0, ''
if from_page <= pn < to_page and p.text.strip():
question_level, p_text = docx_question_level(p)
if not question_level or question_level > 6: # not a question
if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{p_text}'
current_image = self.get_picture(self.doc, p)
last_image = concat_img(last_image, current_image)
else: # is a question
else: # is a question
if last_answer or last_image:
sum_question = '\n'.join(question_stack)
if sum_question:
@ -240,14 +246,14 @@ class Docx(DocxParser):
tbls = []
for tb in self.doc.tables:
html= "<table>"
html = "<table>"
for r in tb.rows:
html += "<tr>"
i = 0
while i < len(r.cells):
span = 1
c = r.cells[i]
for j in range(i+1, len(r.cells)):
for j in range(i + 1, len(r.cells)):
if c.text == r.cells[j].text:
span += 1
i = j
@ -356,7 +362,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if question:
answer += "\n" + lines[i]
else:
fails.append(str(i+1))
fails.append(str(i + 1))
elif len(arr) == 2:
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
@ -429,13 +435,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if not code_block:
question_level, question = mdQuestionLevel(line)
if not question_level or question_level > 6: # not a question
if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{line}'
else: # is a question
else: # is a question
if last_answer.strip():
sum_question = '\n'.join(question_stack)
if sum_question:
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
res.append(beAdoc(deepcopy(doc), sum_question,
markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
last_answer = ''
i = question_level
@ -447,13 +454,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if last_answer.strip():
sum_question = '\n'.join(question_stack)
if sum_question:
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
res.append(beAdoc(deepcopy(doc), sum_question,
markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
return res
elif re.search(r"\.docx$", filename, re.IGNORECASE):
docx_parser = Docx()
qai_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
from_page=0, to_page=10000, callback=callback)
res = tokenize_table(tbls, doc, eng)
for i, (q, a, image) in enumerate(qai_list):
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
@ -466,6 +474,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -64,7 +64,8 @@ def remote_call(filename, binary):
del resume[k]
resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
"updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
"updated_at": datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S")}]))
resume = step_two.parse(resume)
return resume
except Exception:
@ -171,6 +172,9 @@ def chunk(filename, binary=None, callback=None, **kwargs):
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)

View File

@ -51,14 +51,15 @@ class Excel(ExcelParser):
tables = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]
images = Excel._extract_images_from_worksheet(ws,sheetname=sheetname)
images = Excel._extract_images_from_worksheet(ws, sheetname=sheetname)
if images:
image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback, **kwargs)
image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback,
**kwargs)
if image_descriptions and len(image_descriptions) == len(images):
for i, bf in enumerate(image_descriptions):
images[i]["image_description"] = "\n".join(bf[0][1])
for img in images:
if (img["span_type"] == "single_cell"and img.get("image_description")):
if (img["span_type"] == "single_cell" and img.get("image_description")):
pending_cell_images.append(img)
else:
flow_images.append(img)
@ -113,16 +114,17 @@ class Excel(ExcelParser):
tables.append(
(
(
img["image"], # Image.Image
[img["image_description"]] # description list (must be list)
img["image"], # Image.Image
[img["image_description"]] # description list (must be list)
),
[
(0, 0, 0, 0, 0) # dummy position
(0, 0, 0, 0, 0) # dummy position
]
)
)
callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res,tables
callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res, tables
def _parse_headers(self, ws, rows):
if len(rows) == 0:
@ -315,14 +317,15 @@ def trans_bool(s):
def column_data_type(arr):
arr = list(arr)
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
trans = {t: f for f, t in
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
float_flag = False
for a in arr:
if a is None:
continue
if re.match(r"[+-]?[0-9]+$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"):
counts["int"] += 1
if int(str(a)) > 2**63 - 1:
if int(str(a)) > 2 ** 63 - 1:
float_flag = True
break
elif re.match(r"[+-]?[0-9.]{,19}$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"):
@ -370,7 +373,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
dfs,tbls = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback, **kwargs)
dfs, tbls = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback, **kwargs)
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
@ -389,7 +392,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
continue
rows.append(row)
callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
elif re.search(r"\.csv$", filename, re.IGNORECASE):
@ -406,7 +410,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
fails = []
rows = []
for i, row in enumerate(all_rows[1 + from_page : 1 + to_page]):
for i, row in enumerate(all_rows[1 + from_page: 1 + to_page]):
if len(row) != len(headers):
fails.append(str(i + from_page))
continue
@ -415,7 +419,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
callback(
0.3,
(f"Extract records: {from_page}~{from_page + len(rows)}" +
(f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else ""))
(f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else ""))
)
dfs = [pd.DataFrame(rows, columns=headers)]
@ -445,7 +449,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
df[clmns[j]] = cln
if ty == "text":
txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))]
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
range(len(clmns))]
eng = lang.lower() == "english" # is_english(txts)
for ii, row in df.iterrows():
@ -477,7 +482,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], callback=dummy)

View File

@ -141,17 +141,20 @@ def label_question(question, kbs):
if not tag_kbs:
return tags
tags = settings.retriever.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,
kb.parser_config.get("topn_tags", 3)
)
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,
kb.parser_config.get("topn_tags", 3)
)
return tags
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -263,7 +263,7 @@ class SparkTTS(Base):
raise Exception(error)
def on_close(self, ws, close_status_code, close_msg):
self.audio_queue.put(None) # put None as the end-of-stream marker
self.audio_queue.put(None) # None is terminator
def on_open(self, ws):
def run(*args):

View File

@ -273,7 +273,7 @@ def tokenize(d, txt, eng):
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
def split_with_pattern(d, pattern:str, content:str, eng) -> list:
def split_with_pattern(d, pattern: str, content: str, eng) -> list:
docs = []
txts = [txt for txt in re.split(r"(%s)" % pattern, content, flags=re.DOTALL)]
for j in range(0, len(txts), 2):
@ -281,7 +281,7 @@ def split_with_pattern(d, pattern:str, content:str, eng) -> list:
if not txt:
continue
if j + 1 < len(txts):
txt += txts[j+1]
txt += txts[j + 1]
dd = copy.deepcopy(d)
tokenize(dd, txt, eng)
docs.append(dd)
@ -304,7 +304,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None, child_delimiters_pattern=
except NotImplementedError:
pass
else:
add_positions(d, [[ii]*5])
add_positions(d, [[ii] * 5])
if child_delimiters_pattern:
d["mom_with_weight"] = ck
@ -325,7 +325,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images, child_delimiters_patte
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
d["image"] = image
add_positions(d, [[ii]*5])
add_positions(d, [[ii] * 5])
if child_delimiters_pattern:
d["mom_with_weight"] = ck
res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng))
@ -658,7 +658,8 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
if "content_ltks" in ck:
ck["content_ltks"] = rag_tokenizer.tokenize(combined)
if "content_sm_ltks" in ck:
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
if positioned_indices:
chunks[:] = [chunks[i] for i in ordered_indices]
@ -764,8 +765,8 @@ def not_title(txt):
return True
return re.search(r"[,;,。;!!]", txt)
def tree_merge(bull, sections, depth):
def tree_merge(bull, sections, depth):
if not sections or bull < 0:
return sections
if isinstance(sections[0], type("")):
@ -777,16 +778,17 @@ def tree_merge(bull, sections, depth):
def get_level(bull, section):
text, layout = section
text = re.sub(r"\u3000", " ", text).strip()
text = re.sub(r"\u3000", " ", text).strip()
for i, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, text.strip()):
return i+1, text
return i + 1, text
else:
if re.search(r"(title|head)", layout) and not not_title(text):
return len(BULLET_PATTERN[bull])+1, text
return len(BULLET_PATTERN[bull]) + 1, text
else:
return len(BULLET_PATTERN[bull])+2, text
return len(BULLET_PATTERN[bull]) + 2, text
level_set = set()
lines = []
for section in sections:
@ -812,8 +814,8 @@ def tree_merge(bull, sections, depth):
return [element for element in root.get_tree() if element]
def hierarchical_merge(bull, sections, depth):
def hierarchical_merge(bull, sections, depth):
if not sections or bull < 0:
return []
if isinstance(sections[0], type("")):
@ -922,10 +924,10 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
@ -957,7 +959,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
return cks
for sec, pos in sections:
add_chunk("\n"+sec, pos)
add_chunk("\n" + sec, pos)
return cks
@ -978,10 +980,10 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
@ -1025,9 +1027,9 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
if isinstance(text, tuple):
text_str = text[0]
text_pos = text[1] if len(text) > 1 else ""
add_chunk("\n"+text_str, image, text_pos)
add_chunk("\n" + text_str, image, text_pos)
else:
add_chunk("\n"+text, image)
add_chunk("\n" + text, image)
return cks, result_images
@ -1042,7 +1044,7 @@ def docx_question_level(p, bull=-1):
for j, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, txt):
return j + 1, txt
return len(BULLET_PATTERN[bull])+1, txt
return len(BULLET_PATTERN[bull]) + 1, txt
def concat_img(img1, img2):
@ -1211,7 +1213,7 @@ class Node:
child = node.get_children()
if level == 0 and texts:
tree_list.append("\n".join(titles+texts))
tree_list.append("\n".join(titles + texts))
# Titles within configured depth are accumulated into the current path
if 1 <= level <= self.depth:

View File

@ -205,11 +205,11 @@ class FulltextQueryer(QueryBase):
s = 1e-9
for k, v in qtwt.items():
if k in dtwt:
s += v #* dtwt[k]
s += v # * dtwt[k]
q = 1e-9
for k, v in qtwt.items():
q += v #* v
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
q += v # * v
return s / q # math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
if isinstance(content_tks, str):
@ -232,4 +232,5 @@ class FulltextQueryer(QueryBase):
keywords.append(f"{tk}^{w}")
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
{"minimum_should_match": min(3, len(keywords) / 10), "original_query": " ".join(origin_keywords)})
{"minimum_should_match": min(3, len(keywords) / 10),
"original_query": " ".join(origin_keywords)})

View File

@ -66,7 +66,8 @@ class Dealer:
if key in req and req[key] is not None:
condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
"removed_kwd"]:
if key in req and req[key] is not None:
condition[key] = req[key]
return condition
@ -141,7 +142,8 @@ class Dealer:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
orderBy, offset, limit, idx_names, kb_ids,
rank_feature=rank_feature)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
@ -218,8 +220,9 @@ class Dealer:
ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]):
chunk_v[i] = [0.0]*len(ans_v[0])
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
chunk_v[i] = [0.0] * len(ans_v[0])
logging.warning(
"The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
@ -273,7 +276,7 @@ class Dealer:
if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD):
@ -286,8 +289,8 @@ class Dealer:
if denor == 0:
rank_fea.append(0)
else:
rank_fea.append(nor/np.sqrt(denor)/q_denor)
return np.array(rank_fea)*10. + pageranks
rank_fea.append(nor / np.sqrt(denor) / q_denor)
return np.array(rank_fea) * 10. + pageranks
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
@ -358,21 +361,21 @@ class Dealer:
rag_tokenizer.tokenize(inst).split())
def retrieval(
self,
question,
embd_mdl,
tenant_ids,
kb_ids,
page,
page_size,
similarity_threshold=0.2,
vector_similarity_weight=0.3,
top=1024,
doc_ids=None,
aggs=True,
rerank_mdl=None,
highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10},
self,
question,
embd_mdl,
tenant_ids,
kb_ids,
page,
page_size,
similarity_threshold=0.2,
vector_similarity_weight=0.3,
top=1024,
doc_ids=None,
aggs=True,
rerank_mdl=None,
highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10},
):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
@ -395,7 +398,8 @@ class Dealer:
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
rank_feature=rank_feature)
if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model(
@ -558,13 +562,14 @@ class Dealer:
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
return True
@ -580,11 +585,11 @@ class Dealer:
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6):
def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
if not chunks:
return []
idx_nms = [index_name(tid) for tid in tenant_ids]
@ -594,9 +599,10 @@ class Dealer:
ranks[ck["doc_id"]] = 0
ranks[ck["doc_id"]] += ck["similarity"]
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0]
doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
kb_ids = [doc_id2kb_id[doc_id]]
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
OrderByExpr(), 0, 128, idx_nms,
kb_ids)
toc = []
dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
@ -608,7 +614,7 @@ class Dealer:
if not toc:
return chunks
ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn*2))
ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2))
if not ids:
return chunks
@ -644,9 +650,9 @@ class Dealer:
break
chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]
return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]
def retrieval_by_children(self, chunks:list[dict], tenant_ids:list[str]):
def retrieval_by_children(self, chunks: list[dict], tenant_ids: list[str]):
if not chunks:
return []
idx_nms = [index_name(tid) for tid in tenant_ids]
@ -692,4 +698,4 @@ class Dealer:
break
chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1)
return sorted(chunks, key=lambda x: x["similarity"] * -1)

View File

@ -14,129 +14,131 @@
# limitations under the License.
#
m = set(["","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","羿","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","宿","","怀",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","寿","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"广","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","西","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","鹿","",
"万俟","司马","上官","欧阳",
"夏侯","诸葛","闻人","东方",
"赫连","皇甫","尉迟","公羊",
"澹台","公冶","宗政","濮阳",
"淳于","单于","太叔","申屠",
"公孙","仲孙","轩辕","令狐",
"钟离","宇文","长孙","慕容",
"鲜于","闾丘","司徒","司空",
"亓官","司寇","仉督","子车",
"颛孙","端木","巫马","公西",
"漆雕","乐正","壤驷","公良",
"拓跋","夹谷","宰父","榖梁",
"","","","","","","","",
"段干","百里","东郭","南门",
"呼延","","","羊舌","","",
"","","","","","","","",
"梁丘","左丘","东门","西门",
"","","","","","","南宫",
"","","","","","","","",
"第五","",""])
m = set(["", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "羿", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "宿", "", "怀",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "寿", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"广", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "西", "", "", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "鹿", "",
"万俟", "司马", "上官", "欧阳",
"夏侯", "诸葛", "闻人", "东方",
"赫连", "皇甫", "尉迟", "公羊",
"澹台", "公冶", "宗政", "濮阳",
"淳于", "单于", "太叔", "申屠",
"公孙", "仲孙", "轩辕", "令狐",
"钟离", "宇文", "长孙", "慕容",
"鲜于", "闾丘", "司徒", "司空",
"亓官", "司寇", "仉督", "子车",
"颛孙", "端木", "巫马", "公西",
"漆雕", "乐正", "壤驷", "公良",
"拓跋", "夹谷", "宰父", "榖梁",
"", "", "", "", "", "", "", "",
"段干", "百里", "东郭", "南门",
"呼延", "", "", "羊舌", "", "",
"", "", "", "", "", "", "", "",
"梁丘", "左丘", "东门", "西门",
"", "", "", "", "", "", "南宫",
"", "", "", "", "", "", "", "",
"第五", "", ""])
def isit(n):return n.strip() in m
def isit(n): return n.strip() in m

View File

@ -1,4 +1,4 @@
#
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -108,13 +108,14 @@ class Dealer:
if re.match(p, t):
tk = "#"
break
#tk = re.sub(r"([\+\\-])", r"\\\1", tk)
# tk = re.sub(r"([\+\\-])", r"\\\1", tk)
if tk != "#" and tk:
res.append(tk)
return res
def token_merge(self, tks):
def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
def one_term(t):
return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
res, i = [], 0
while i < len(tks):
@ -152,8 +153,8 @@ class Dealer:
tks = []
for t in re.sub(r"[ \t]+", " ", txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r".*[a-zA-Z]$", t) and tks and \
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
re.match(r".*[a-zA-Z]$", t) and tks and \
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
tks[-1] = tks[-1] + " " + t
else:
tks.append(t)
@ -220,14 +221,15 @@ class Dealer:
return 3
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
def idf(s, N):
return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
tw = []
if not preprocess:
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tks])
np.array([ner(t) * postag(t) for t in tks])
wts = [s for s in wts]
tw = list(zip(tks, wts))
else:
@ -236,7 +238,7 @@ class Dealer:
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tt])
np.array([ner(t) * postag(t) for t in tt])
wts = [s for s in wts]
tw.extend(zip(tt, wts))
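
A minimal standalone sketch of the blended idf weighting this hunk reformats. The freq/df/ner/postag lookups are toy stand-ins (assumed values), not the real dictionaries used by Dealer:

import math
import numpy as np

def idf(s, N):
    # Same smoothed idf as above: log10(10 + (N - s + 0.5) / (s + 0.5))
    return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

# Toy stand-ins for the term-frequency / doc-frequency / NER / POS lookups (assumed values)
freq = {"ragflow": 3, "pipeline": 120, "the": 90000}
df = {"ragflow": 10, "pipeline": 5000, "the": 800000}
ner = lambda t: 1.0
postag = lambda t: 1.0

tks = list(freq.keys())
idf1 = np.array([idf(freq[t], 10000000) for t in tks])
idf2 = np.array([idf(df[t], 1000000000) for t in tks])
wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tks])
print(list(zip(tks, wts)))  # rarer terms get larger weights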

View File

@ -3,4 +3,4 @@ from . import generator
__all__ = [name for name in dir(generator)
if not name.startswith('_')]
globals().update({name: getattr(generator, name) for name in __all__})
globals().update({name: getattr(generator, name) for name in __all__})

View File

@ -28,17 +28,16 @@ from rag.prompts.template import load_prompt
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string
STOP_TOKEN="<|STOP|>"
COMPLETE_TASK="complete_task"
STOP_TOKEN = "<|STOP|>"
COMPLETE_TASK = "complete_task"
INPUT_UTILIZATION = 0.5
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
def chunks_format(reference):
return [
{
"id": get_value(chunk, "chunk_id", "id"),
@ -126,7 +125,7 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500))
cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
cnt += draw_node("URL", ck['url']) if "url" in ck else ""
cnt += draw_node("URL", ck['url']) if "url" in ck else ""
for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
cnt += draw_node(k, v)
cnt += "\n└── Content:\n"
@ -173,7 +172,7 @@ ASK_SUMMARY = load_prompt("ask_summary")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
def citation_prompt(user_defined_prompts: dict={}) -> str:
def citation_prompt(user_defined_prompts: dict = {}) -> str:
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
return template.render()
@ -258,9 +257,11 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query,
languages=languages)
ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}],
{"temperature": 0.2})
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if ans.find("**ERROR**") >= 0:
return query
@ -332,7 +333,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
"description": "When you have the final answer and are ready to complete the task, call this function with your answer",
"parameters": {
"type": "object",
"properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
"properties": {
"answer": {"type": "string", "description": "The final answer to the user's question"}},
"required": ["answer"]
}
}
@ -341,7 +343,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
name = tool["function"]["name"]
desc[name] = tool
return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in
enumerate(desc.items())])
def form_history(history, limit=-6):
@ -350,14 +353,14 @@ def form_history(history, limit=-6):
if h["role"] == "system":
continue
role = "USER"
if h["role"].upper()!= role:
if h["role"].upper() != role:
role = "AGENT"
context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
return context
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict],
user_defined_prompts: dict = {}):
tools_desc = tool_schema(tools_description)
context = ""
@ -375,7 +378,8 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
return kwd
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
async def next_step_async(chat_mdl, history: list, tools_description: list[dict], task_desc,
user_defined_prompts: dict = {}):
if not tools_description:
return "", 0
desc = tool_schema(tools_description)
@ -396,7 +400,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
return json_str, tk_cnt
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict = {}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@ -419,7 +423,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple
def form_message(system_prompt, user_prompt):
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
def structured_output_prompt(schema=None) -> str:
@ -427,27 +431,29 @@ def structured_output_prompt(schema=None) -> str:
return template.render(schema=schema)
async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict = {}) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name,
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str],
user_defined_prompts: dict = {}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
system_prompt = template.render(goal=goal, sub_goal=sub_goal,
results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict:
meta_data_structure = {}
for key, values in meta_data.items():
meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
@ -471,13 +477,13 @@ async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
return {"conditions": []}
async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
async def gen_json(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None):
from graphrag.utils import get_llm_cache, set_llm_cache
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
res = json_repair.loads(ans)
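
A tiny illustration of what the cleanup regex in gen_json strips from a raw model answer before json_repair runs; the sample answer below is invented:

import re

# Invented sample: a reasoning block followed by a fenced JSON payload
ans = "Let me think...</think>```json\n{\"exists\": true}\n```"
cleaned = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
print(cleaned)  # -> {"exists": true}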
@ -488,10 +494,13 @@ async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None
TOC_DETECTION = load_prompt("toc_detection")
async def detect_table_of_contents(page_1024:list[str], chat_mdl):
async def detect_table_of_contents(page_1024: list[str], chat_mdl):
toc_secs = []
for i, sec in enumerate(page_1024[:22]):
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.",
chat_mdl)
if toc_secs and not ans["exists"]:
break
toc_secs.append(sec)
@ -500,14 +509,17 @@ async def detect_table_of_contents(page_1024:list[str], chat_mdl):
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
async def extract_table_of_contents(toc_pages, chat_mdl):
if not toc_pages:
return []
return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)),
"Only JSON please.", chat_mdl)
async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
async def toc_index_extractor(toc: list[dict], content: str, chat_mdl):
tob_extractor_prompt = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
@ -529,18 +541,21 @@ async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
If the title of the section are not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False,
indent=2) + '\nDocument pages:\n' + content
return await gen_json(prompt, "Only JSON please.", chat_mdl)
TOC_INDEX = load_prompt("toc_index")
async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
if not toc_arr or not sections:
return []
toc_map = {}
for i, it in enumerate(toc_arr):
k1 = (it["structure"]+it["title"]).replace(" ", "")
k1 = (it["structure"] + it["title"]).replace(" ", "")
k2 = it["title"].strip()
if k1 not in toc_map:
toc_map[k1] = []
@ -558,6 +573,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
toc_arr[j]["indices"].append(i)
all_pathes = []
def dfs(start, path):
nonlocal all_pathes
if start >= len(toc_arr):
@ -565,7 +581,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
all_pathes.append(path)
return
if not toc_arr[start]["indices"]:
dfs(start+1, path)
dfs(start + 1, path)
return
added = False
for j in toc_arr[start]["indices"]:
@ -574,12 +590,12 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
_path = deepcopy(path)
_path.append((j, start))
added = True
dfs(start+1, _path)
dfs(start + 1, _path)
if not added and path:
all_pathes.append(path)
dfs(0, [])
path = max(all_pathes, key=lambda x:len(x))
path = max(all_pathes, key=lambda x: len(x))
for it in toc_arr:
it["indices"] = []
for j, i in path:
@ -588,24 +604,24 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
i = 0
while i < len(toc_arr):
it = toc_arr[i]
it = toc_arr[i]
if it["indices"]:
i += 1
continue
if i>0 and toc_arr[i-1]["indices"]:
st_i = toc_arr[i-1]["indices"][-1]
if i > 0 and toc_arr[i - 1]["indices"]:
st_i = toc_arr[i - 1]["indices"][-1]
else:
st_i = 0
e = i + 1
while e <len(toc_arr) and not toc_arr[e]["indices"]:
while e < len(toc_arr) and not toc_arr[e]["indices"]:
e += 1
if e >= len(toc_arr):
e = len(sections)
else:
e = toc_arr[e]["indices"][0]
for j in range(st_i, min(e+1, len(sections))):
for j in range(st_i, min(e + 1, len(sections))):
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
structure=it["structure"],
title=it["title"],
@ -656,11 +672,15 @@ async def toc_transformer(toc_pages, chat_mdl):
toc_content = "\n".join(toc_pages)
prompt = init_prompt + '\n Given table of contents\n:' + toc_content
def clean_toc(arr):
for a in arr:
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content,
json.dumps(last_complete, ensure_ascii=False, indent=2),
chat_mdl)
clean_toc(last_complete)
if if_complete == "yes":
return last_complete
@ -682,13 +702,17 @@ async def toc_transformer(toc_pages, chat_mdl):
break
clean_toc(new_complete)
last_complete.extend(new_complete)
if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content,
json.dumps(last_complete, ensure_ascii=False,
indent=2), chat_mdl)
return last_complete
TOC_LEVELS = load_prompt("assign_toc_levels")
async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
async def assign_toc_levels(toc_secs, chat_mdl, gen_conf={"temperature": 0.2}):
if not toc_secs:
return []
return await gen_json(
@ -701,12 +725,15 @@ async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2})
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
try:
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(
text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
@ -743,7 +770,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)
input_budget = 1024 if input_budget > 1024 else input_budget
input_budget = 1024 if input_budget > 1024 else input_budget
chunk_sections = split_chunks(chunks, input_budget)
titles = []
@ -798,7 +825,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
if sorted_list:
max_lvl = sorted_list[-1]
merged = []
for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
for _, (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
if prune and toc_item.get("level", "0") >= max_lvl:
continue
merged.append({
@ -812,12 +839,15 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
import numpy as np
try:
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join(
[json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
@ -828,17 +858,19 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
for id in ti.get("ids", []):
if id not in id2score:
id2score[id] = []
id2score[id].append(sc["score"]/5.)
id2score[id].append(sc["score"] / 5.)
for id in id2score.keys():
id2score[id] = np.mean(id2score[id])
return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn]
return [(id, sc) for id, sc in list(id2score.items()) if sc >= 0.3][:topn]
except Exception as e:
logging.exception(e)
return []
META_DATA = load_prompt("meta_data")
async def gen_metadata(chat_mdl, schema:dict, content:str):
async def gen_metadata(chat_mdl, schema: dict, content: str):
template = PROMPT_JINJA_ENV.from_string(META_DATA)
for k, desc in schema["properties"].items():
if "enum" in desc and not desc.get("enum"):
@ -849,4 +881,4 @@ async def gen_metadata(chat_mdl, schema:dict, content:str):
user_prompt = "Output: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

View File

@ -1,6 +1,5 @@
import os
PROMPT_DIR = os.path.dirname(__file__)
_loaded_prompts = {}

View File

@ -48,13 +48,15 @@ def main():
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))
except Exception as e:
traceback.print_stack(e)
logging.error(f"Error to get data from REDIS: {e}")
traceback.print_stack()
except Exception as e:
traceback.print_stack(e)
logging.error(f"Error to check REDIS connection: {e}")
traceback.print_stack()
if __name__ == "__main__":
while True:
main()
close_connection()
time.sleep(1)
time.sleep(1)

View File

@ -19,16 +19,15 @@ import requests
import base64
import asyncio
URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
JSON_DATA = {
"conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
"Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
"word": "" # User question, don't need to initialize
"conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
"Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
"word": "" # User question, don't need to initialize
}
DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" #Get DISCORD_BOT_KEY from Discord Application
DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" # Get DISCORD_BOT_KEY from Discord Application
intents = discord.Intents.default()
intents.message_content = True
@ -50,7 +49,7 @@ async def on_message(message):
if len(message.content.split('> ')) == 1:
await message.channel.send("Hi~ How can I help you? ")
else:
JSON_DATA['word']=message.content.split('> ')[1]
JSON_DATA['word'] = message.content.split('> ')[1]
response = requests.post(URL, json=JSON_DATA)
response_data = response.json().get('data', [])
image_bool = False
@ -61,9 +60,9 @@ async def on_message(message):
if i['type'] == 3:
image_bool = True
image_data = base64.b64decode(i['url'])
with open('tmp_image.png','wb') as file:
with open('tmp_image.png', 'wb') as file:
file.write(image_data)
image= discord.File('tmp_image.png')
image = discord.File('tmp_image.png')
await message.channel.send(f"{message.author.mention}{res}")

View File

@ -38,7 +38,8 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
from common.config_utils import show_configs
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, \
MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
from common.constants import FileSource, TaskStatus
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector
@ -96,7 +97,7 @@ class SyncBase:
if task["poll_range_start"]:
next_update = task["poll_range_start"]
for document_batch in document_batch_generator:
for document_batch in document_batch_generator:
if not document_batch:
continue
@ -161,6 +162,7 @@ class SyncBase:
def _get_source_prefix(self):
return ""
class _BlobLikeBase(SyncBase):
DEFAULT_BUCKET_TYPE: str = "s3"
@ -199,22 +201,27 @@ class _BlobLikeBase(SyncBase):
)
return document_batch_generator
class S3(_BlobLikeBase):
SOURCE_NAME: str = FileSource.S3
DEFAULT_BUCKET_TYPE: str = "s3"
class R2(_BlobLikeBase):
SOURCE_NAME: str = FileSource.R2
DEFAULT_BUCKET_TYPE: str = "r2"
class OCI_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.OCI_STORAGE
DEFAULT_BUCKET_TYPE: str = "oci_storage"
class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.GOOGLE_CLOUD_STORAGE
DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"
class Confluence(SyncBase):
SOURCE_NAME: str = FileSource.CONFLUENCE
@ -248,7 +255,9 @@ class Confluence(SyncBase):
index_recursively=index_recursively,
)
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"],
connector_name=DocumentSource.CONFLUENCE,
credential_json=self.conf["credentials"])
self.connector.set_credentials_provider(credentials_provider)
# Determine the time range for synchronization based on reindex or poll_range_start
@ -280,7 +289,8 @@ class Confluence(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator:
if failure is not None:
logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure))
logging.warning("Confluence connector failure: %s",
getattr(failure, "failure_message", failure))
continue
if document is not None:
pending_docs.append(document)
@ -300,7 +310,7 @@ class Confluence(SyncBase):
async def async_wrapper():
for batch in document_batches():
yield batch
logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
return async_wrapper()
@ -314,10 +324,12 @@ class Notion(SyncBase):
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
return document_generator
@ -340,10 +352,12 @@ class Discord(SyncBase):
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
return document_generator
@ -485,7 +499,8 @@ class GoogleDrive(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator:
if failure is not None:
logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
logging.warning("Google Drive connector failure: %s",
getattr(failure, "failure_message", failure))
continue
if document is not None:
pending_docs.append(document)
@ -646,10 +661,10 @@ class WebDAV(SyncBase):
remote_path=self.conf.get("remote_path", "/")
)
self.connector.load_credentials(self.conf["credentials"])
logging.info(f"Task info: reindex={task['reindex']}, poll_range_start={task['poll_range_start']}")
if task["reindex"]=="1" or not task["poll_range_start"]:
if task["reindex"] == "1" or not task["poll_range_start"]:
logging.info("Using load_from_state (full sync)")
document_batch_generator = self.connector.load_from_state()
begin_info = "totally"
@ -659,14 +674,15 @@ class WebDAV(SyncBase):
logging.info(f"Polling WebDAV from {task['poll_range_start']} (ts: {start_ts}) to now (ts: {end_ts})")
document_batch_generator = self.connector.poll_source(start_ts, end_ts)
begin_info = "from {}".format(task["poll_range_start"])
logging.info("Connect to WebDAV: {}(path: {}) {}".format(
self.conf["base_url"],
self.conf.get("remote_path", "/"),
begin_info
))
return document_batch_generator
class Moodle(SyncBase):
SOURCE_NAME: str = FileSource.MOODLE
@ -675,7 +691,7 @@ class Moodle(SyncBase):
moodle_url=self.conf["moodle_url"],
batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE)
)
self.connector.load_credentials(self.conf["credentials"])
# Determine the time range for synchronization based on reindex or poll_range_start
@ -689,7 +705,7 @@ class Moodle(SyncBase):
begin_info = "totally"
else:
document_generator = self.connector.poll_source(
poll_start.timestamp(),
poll_start.timestamp(),
datetime.now(timezone.utc).timestamp()
)
begin_info = "from {}".format(poll_start)
@ -718,7 +734,7 @@ class BOX(SyncBase):
token = AccessToken(
access_token=credential['access_token'],
refresh_token=credential['refresh_token'],
)
)
auth.token_storage.store(token)
self.connector.load_credentials(auth)
@ -739,6 +755,7 @@ class BOX(SyncBase):
logging.info("Connect to Box: folder_id({}) {}".format(self.conf["folder_id"], begin_info))
return document_generator
class Airtable(SyncBase):
SOURCE_NAME: str = FileSource.AIRTABLE
@ -784,6 +801,7 @@ class Airtable(SyncBase):
return document_generator
func_factory = {
FileSource.S3: S3,
FileSource.R2: R2,

View File

@ -92,7 +92,7 @@ FACTORY = {
}
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
"dataflow" : PipelineTaskType.PARSE,
"dataflow": PipelineTaskType.PARSE,
"raptor": PipelineTaskType.RAPTOR,
"graphrag": PipelineTaskType.GRAPH_RAG,
"mindmap": PipelineTaskType.MINDMAP,
@ -221,7 +221,7 @@ async def get_storage_binary(bucket, name):
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
@timeout(60*80, 1)
@timeout(60 * 80, 1)
async def build_chunks(task, progress_callback):
if task["size"] > settings.DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
@ -283,7 +283,8 @@ async def build_chunks(task, progress_callback):
try:
d = copy.deepcopy(document)
d.update(chunk)
d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["id"] = xxhash.xxh64(
(chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"):
@ -328,9 +329,11 @@ async def build_chunks(task, progress_callback):
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
tasks = []
for d in docs:
tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
tasks.append(
asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
@ -355,9 +358,11 @@ async def build_chunks(task, progress_callback):
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
tasks = []
for d in docs:
tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
tasks.append(
asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
@ -374,15 +379,18 @@ async def build_chunks(task, progress_callback):
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
async def gen_metadata_task(chat_mdl, d):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", task["parser_config"]["metadata"])
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
task["parser_config"]["metadata"])
if not cached:
async with chat_limiter:
cached = await gen_metadata(chat_mdl,
metadata_schema(task["parser_config"]["metadata"]),
d["content_with_weight"])
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", task["parser_config"]["metadata"])
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
task["parser_config"]["metadata"])
if cached:
d["metadata_obj"] = cached
tasks = []
for d in docs:
tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
@ -430,7 +438,8 @@ async def build_chunks(task, progress_callback):
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return None
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
docs_to_tag.append(d)
@ -438,7 +447,7 @@ async def build_chunks(task, progress_callback):
async def doc_content_tagging(chat_mdl, d, topn_tags):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
async with chat_limiter:
@ -454,6 +463,7 @@ async def build_chunks(task, progress_callback):
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
tasks = []
for d in docs_to_tag:
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
@ -473,21 +483,22 @@ async def build_chunks(task, progress_callback):
def build_TOC(task, docs, progress_callback):
progress_callback(msg="Start to generate table of content ...")
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
docs = sorted(docs, key=lambda d:(
docs = sorted(docs, key=lambda d: (
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
toc: list[dict] = asyncio.run(
run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
if ii == len(toc) - 1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
@ -499,7 +510,8 @@ def build_TOC(task, docs, progress_callback):
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["id"] = xxhash.xxh64(
(d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
@ -532,12 +544,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
@timeout(60)
def batch_encode(txts):
nonlocal mdl
return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
return mdl.encode([truncate(c, mdl.max_length - 10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0:
cnts_ = vts
else:
@ -545,7 +557,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count += c
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
cnts = cnts_
filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
if not filename_embd_weight:
filename_embd_weight = 0.1
title_w = float(filename_embd_weight)
@ -588,7 +600,8 @@ async def run_dataflow(task: dict):
return
if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
@ -610,25 +623,27 @@ async def run_dataflow(task: dict):
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1)
delta = 0.20 / (len(texts) // settings.EMBEDDING_BATCH_SIZE + 1)
prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}")
if i % (len(texts) // settings.EMBEDDING_BATCH_SIZE / 100 + 1) == 1:
set_progress(task_id, prog=prog, msg=f"{i + 1} / {len(texts) // settings.EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
@ -636,10 +651,10 @@ async def run_dataflow(task: dict):
ck["q_%d_vec" % len(v)] = v
except Exception as e:
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
metadata = {}
for ck in chunks:
ck["doc_id"] = doc_id
@ -686,15 +701,19 @@ async def run_dataflow(task: dict):
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
dsl=str(pipeline))
@timeout(3600)
@ -702,7 +721,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
raptor_config = kb_parser_config.get("raptor", {})
vctr_nm = "q_%d_vec"%vector_size
vctr_nm = "q_%d_vec" % vector_size
res = []
tk_count = 0
@ -747,17 +766,17 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for x, doc_id in enumerate(doc_ids):
chunks = []
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, doc_id)
callback(prog=(x+1.)/len(doc_ids))
callback(prog=(x + 1.) / len(doc_ids))
else:
chunks = []
for doc_id in doc_ids:
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, fake_doc_id)
@ -792,19 +811,22 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mom_ck["available_int"] = 0
flds = list(mom_ck.keys())
for fld in flds:
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int"]:
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
"position_int"]:
del mom_ck[fld]
mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -821,7 +843,8 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
search.index_name(task_tenant_id), task_dataset_id, )
tasks = []
for chunk_id in chunk_ids:
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
@ -838,7 +861,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
return True
@timeout(60*60*3, 1)
@timeout(60 * 60 * 3, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
@ -914,7 +937,7 @@ async def do_handle_task(task):
},
}
)
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
@ -943,7 +966,7 @@ async def do_handle_task(task):
doc_ids=task.get("doc_ids", []),
)
if fake_doc_ids := task.get("doc_ids", []):
task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
# Either using graphrag or Standard chunking methods
elif task_type == "graphrag":
ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
@ -968,11 +991,10 @@ async def do_handle_task(task):
}
}
)
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return
graphrag_conf = kb_parser_config.get("graphrag", {})
start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
@ -1030,7 +1052,7 @@ async def do_handle_task(task):
return True
e = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
return bool(e)
try:
if not await _maybe_insert_es(chunks):
return
@ -1084,8 +1106,8 @@ async def do_handle_task(task):
f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled."
)
async def handle_task():
async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
@ -1093,7 +1115,8 @@ async def handle_task():
return
task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
PipelineTaskType.PARSE) or PipelineTaskType.PARSE
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
@ -1119,7 +1142,9 @@ async def handle_task():
if task_type in ["graphrag", "raptor", "mindmap"]:
task_document_ids = task["doc_ids"]
if not task.get("dataflow_id", ""):
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
task_type=pipeline_task_type,
fake_document_ids=task_document_ids)
redis_msg.ack()
@ -1249,6 +1274,7 @@ async def main():
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)

View File

@ -42,8 +42,10 @@ class RAGFlowAzureSpnBlob:
pass
try:
credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id, client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name, credential=credentials)
credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id,
client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name,
credential=credentials)
except Exception:
logging.exception("Fail to connect %s" % self.account_url)
@ -104,4 +106,4 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return None
return None

View File

@ -25,7 +25,8 @@ from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64)
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str = "imagetemps"):
import logging
from io import BytesIO
from rag.svr.task_executor import minio_limiter
@ -74,7 +75,7 @@ async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="
del d["image"]
def id2image(image_id:str|None, storage_get_func: partial):
def id2image(image_id: str | None, storage_get_func: partial):
if not image_id:
return
arr = image_id.split("-")

View File

@ -16,11 +16,13 @@
import logging
from common.crypto_utils import CryptoUtil
# from common.decorator import singleton
class EncryptedStorageWrapper:
"""Encrypted storage wrapper that wraps existing storage implementations to provide transparent encryption"""
def __init__(self, storage_impl, algorithm="aes-256-cbc", key=None, iv=None):
"""
Initialize encrypted storage wrapper
@ -34,16 +36,16 @@ class EncryptedStorageWrapper:
self.storage_impl = storage_impl
self.crypto = CryptoUtil(algorithm=algorithm, key=key, iv=iv)
self.encryption_enabled = True
# Check if storage implementation has required methods
# todo: Consider abstracting a storage base class to ensure these methods exist
required_methods = ["put", "get", "rm", "obj_exist", "health"]
for method in required_methods:
if not hasattr(storage_impl, method):
raise AttributeError(f"Storage implementation missing required method: {method}")
logging.info(f"EncryptedStorageWrapper initialized with algorithm: {algorithm}")
def put(self, bucket, fnm, binary, tenant_id=None):
"""
Encrypt and store data
@ -59,15 +61,15 @@ class EncryptedStorageWrapper:
"""
if not self.encryption_enabled:
return self.storage_impl.put(bucket, fnm, binary, tenant_id)
try:
encrypted_binary = self.crypto.encrypt(binary)
return self.storage_impl.put(bucket, fnm, encrypted_binary, tenant_id)
except Exception as e:
logging.exception(f"Failed to encrypt and store data: {bucket}/{fnm}, error: {str(e)}")
raise
def get(self, bucket, fnm, tenant_id=None):
"""
Retrieve and decrypt data
@ -83,21 +85,21 @@ class EncryptedStorageWrapper:
try:
# Get encrypted data
encrypted_binary = self.storage_impl.get(bucket, fnm, tenant_id)
if encrypted_binary is None:
return None
if not self.encryption_enabled:
return encrypted_binary
# Decrypt data
decrypted_binary = self.crypto.decrypt(encrypted_binary)
return decrypted_binary
except Exception as e:
logging.exception(f"Failed to get and decrypt data: {bucket}/{fnm}, error: {str(e)}")
raise
def rm(self, bucket, fnm, tenant_id=None):
"""
Delete data (same as original storage implementation, no decryption needed)
@ -111,7 +113,7 @@ class EncryptedStorageWrapper:
Deletion result
"""
return self.storage_impl.rm(bucket, fnm, tenant_id)
def obj_exist(self, bucket, fnm, tenant_id=None):
"""
Check if object exists (same as original storage implementation, no decryption needed)
@ -125,7 +127,7 @@ class EncryptedStorageWrapper:
Whether the object exists
"""
return self.storage_impl.obj_exist(bucket, fnm, tenant_id)
def health(self):
"""
Health check (uses the original storage implementation's method)
@ -134,7 +136,7 @@ class EncryptedStorageWrapper:
Health check result
"""
return self.storage_impl.health()
def bucket_exists(self, bucket):
"""
Check if bucket exists (if the original storage implementation has this method)
@ -148,7 +150,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "bucket_exists"):
return self.storage_impl.bucket_exists(bucket)
return False
def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
"""
Get presigned URL (if the original storage implementation has this method)
@ -165,7 +167,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "get_presigned_url"):
return self.storage_impl.get_presigned_url(bucket, fnm, expires, tenant_id)
return None
def scan(self, bucket, fnm, tenant_id=None):
"""
Scan objects (if the original storage implementation has this method)
@ -181,7 +183,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "scan"):
return self.storage_impl.scan(bucket, fnm, tenant_id)
return None
def copy(self, src_bucket, src_path, dest_bucket, dest_path):
"""
Copy object (if the original storage implementation has this method)
@ -198,7 +200,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "copy"):
return self.storage_impl.copy(src_bucket, src_path, dest_bucket, dest_path)
return False
def move(self, src_bucket, src_path, dest_bucket, dest_path):
"""
Move object (if the original storage implementation has this method)
@ -215,7 +217,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "move"):
return self.storage_impl.move(src_bucket, src_path, dest_bucket, dest_path)
return False
def remove_bucket(self, bucket):
"""
Remove bucket (if the original storage implementation has this method)
@ -229,17 +231,18 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "remove_bucket"):
return self.storage_impl.remove_bucket(bucket)
return False
def enable_encryption(self):
"""Enable encryption"""
self.encryption_enabled = True
logging.info("Encryption enabled")
def disable_encryption(self):
"""Disable encryption"""
self.encryption_enabled = False
logging.info("Encryption disabled")
# Create singleton wrapper function
def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_enabled=True):
"""
@ -255,12 +258,12 @@ def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_
Encrypted storage wrapper instance
"""
wrapper = EncryptedStorageWrapper(storage_impl, algorithm=algorithm, key=key)
wrapper.encryption_enabled = encryption_enabled
if encryption_enabled:
logging.info("Encryption enabled in storage wrapper")
else:
logging.info("Encryption disabled in storage wrapper")
return wrapper
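
A minimal usage sketch for the wrapper above, assuming create_encrypted_storage is importable; the module path and key format below are guesses, not taken from the diff:

# Hypothetical import path -- the diff does not show the module name.
from rag.utils.encrypted_storage import create_encrypted_storage

class InMemoryStorage:
    """Toy backend exposing the methods the wrapper checks for: put/get/rm/obj_exist/health."""
    def __init__(self):
        self._data = {}
    def put(self, bucket, fnm, binary, tenant_id=None):
        self._data[(bucket, fnm)] = binary
    def get(self, bucket, fnm, tenant_id=None):
        return self._data.get((bucket, fnm))
    def rm(self, bucket, fnm, tenant_id=None):
        self._data.pop((bucket, fnm), None)
    def obj_exist(self, bucket, fnm, tenant_id=None):
        return (bucket, fnm) in self._data
    def health(self):
        return True

# key length/format is a guess; CryptoUtil's real expectations aren't shown in the diff
storage = create_encrypted_storage(InMemoryStorage(), algorithm="aes-256-cbc",
                                    key=b"0" * 32, encryption_enabled=True)
storage.put("kb", "doc.bin", b"plain text")           # stored encrypted by the wrapper
assert storage.get("kb", "doc.bin") == b"plain text"  # decrypted transparently on read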

View File

@ -32,7 +32,6 @@ ATTEMPT_TIME = 2
@singleton
class ESConnection(ESConnectionBase):
"""
CRUD operations
"""
@ -82,8 +81,9 @@ class ESConnection(ESConnectionBase):
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
MatchDenseExpr) and isinstance(
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
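
For clarity, how the weighted_sum fusion weights string is interpreted here; the "0.05,0.95" value is only an example:

# "weights" arrives as a comma-separated pair; the second element is taken as the vector weight
weights = "0.05,0.95"  # example value
vector_similarity_weight = float(weights.split(",")[1])
text_boost = 1.0 - vector_similarity_weight  # used as bool_query.boost for the full-text part
print(vector_similarity_weight, text_boost)  # 0.95 0.05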
@ -93,9 +93,9 @@ class ESConnection(ESConnectionBase):
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bool_query.must.append(Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
@ -146,7 +146,7 @@ class ESConnection(ESConnectionBase):
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
# print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=index_names,
body=q,
timeout="600s",
@ -220,13 +220,15 @@ class ESConnection(ESConnectionBase):
try:
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
self.logger.exception(
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=index_name, id=chunk_id, doc=doc)
return True
except Exception as e:
self.logger.exception(
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(
e))
break
return False

View File

@ -25,18 +25,23 @@ import PyPDF2
from docx import Document
import olefile
def _is_zip(h: bytes) -> bool:
return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08")
def _is_pdf(h: bytes) -> bool:
return h.startswith(b"%PDF-")
def _is_ole(h: bytes) -> bool:
return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1")
def _sha10(b: bytes) -> str:
return hashlib.sha256(b).hexdigest()[:10]
def _guess_ext(b: bytes) -> str:
h = b[:8]
if _is_zip(h):
@ -58,13 +63,14 @@ def _guess_ext(b: bytes) -> str:
return ".doc"
return ".bin"
# Try to extract the real embedded payload from OLE's Ole10Native
def _extract_ole10native_payload(data: bytes) -> bytes:
try:
pos = 0
if len(data) < 4:
return data
_ = int.from_bytes(data[pos:pos+4], "little")
_ = int.from_bytes(data[pos:pos + 4], "little")
pos += 4
# filename/src/tmp (NUL-terminated ANSI)
for _ in range(3):
@ -74,14 +80,15 @@ def _extract_ole10native_payload(data: bytes) -> bytes:
pos += 4
if pos + 4 > len(data):
return data
size = int.from_bytes(data[pos:pos+4], "little")
size = int.from_bytes(data[pos:pos + 4], "little")
pos += 4
if pos + size <= len(data):
return data[pos:pos+size]
return data[pos:pos + size]
except Exception:
pass
return data
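
To make the Ole10Native layout concrete, a self-contained sketch that builds a synthetic stream and walks it with the same steps as the function above. The string-skipping loop is elided by the hunk, so a simple NUL scan stands in for it, and all names and sizes are made up:

# Synthetic Ole10Native-style payload: leading size dword, three NUL-terminated ANSI
# strings (filename, source path, temp path), one skipped dword, then a 4-byte
# little-endian length followed by the embedded bytes.
payload = b"%PDF-1.4 fake embedded file"
blob = b"\x00\x00\x00\x00"                        # leading size dword (ignored by the parser)
blob += b"report.pdf\x00" + b"C:\\src\\report.pdf\x00" + b"C:\\tmp\\report.pdf\x00"
blob += b"\x00\x00\x00\x00"                       # 4 bytes skipped after the strings
blob += len(payload).to_bytes(4, "little") + payload

pos = 4                                           # skip the leading size dword
for _ in range(3):                                # skip the three NUL-terminated strings
    pos = blob.index(b"\x00", pos) + 1
pos += 4                                          # skip the next dword
size = int.from_bytes(blob[pos:pos + 4], "little")
pos += 4
print(blob[pos:pos + size])                       # b'%PDF-1.4 fake embedded file'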
def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]:
"""
Only extract the 'first layer' of embedding, returning raw (filename, bytes).
@ -163,7 +170,7 @@ def extract_links_from_docx(docx_bytes: bytes):
# Each relationship may represent a hyperlink, image, footer, etc.
for rel in document.part.rels.values():
if rel.reltype == (
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
):
links.add(rel.target_ref)
@ -198,6 +205,8 @@ def extract_links_from_pdf(pdf_bytes: bytes):
_GLOBAL_SESSION: Optional[requests.Session] = None
def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
"""Get or create a global reusable session."""
global _GLOBAL_SESSION
@ -216,10 +225,10 @@ def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
def extract_html(
url: str,
timeout: float = 60.0,
headers: Optional[Dict[str, str]] = None,
max_retries: int = 2,
url: str,
timeout: float = 60.0,
headers: Optional[Dict[str, str]] = None,
max_retries: int = 2,
) -> Tuple[Optional[bytes], Dict[str, str]]:
"""
Extract the full HTML page as raw bytes from a given URL.
@ -260,4 +269,4 @@ def extract_html(
metadata["error"] = f"Request failed: {e}"
continue
return None, metadata
return None, metadata
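# Usage sketch for extract_html() as declared above: it returns the raw page bytes
# plus a metadata dict, or (None, metadata) with an "error" entry once max_retries
# attempts have failed. The URL is a placeholder and network access is assumed.
html_bytes, meta = extract_html("https://example.com", timeout=10.0, max_retries=2)
if html_bytes is None:
    print("fetch failed:", meta.get("error"))
else:
    print("fetched", len(html_bytes), "bytes")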

View File

@ -204,4 +204,4 @@ class RAGFlowGCS:
return False
except Exception:
logging.exception(f"Fail to move {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
return False
return False

View File

@ -28,7 +28,6 @@ from common.doc_store.infinity_conn_base import InfinityConnectionBase
@singleton
class InfinityConnection(InfinityConnectionBase):
"""
DataFrame and field conversion
"""
@ -83,24 +82,23 @@ class InfinityConnection(InfinityConnectionBase):
tokens[0] = field
return "^".join(tokens)
"""
CRUD operations
"""
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
@ -159,7 +157,8 @@ class InfinityConnection(InfinityConnectionBase):
if table_found:
break
if not table_found:
self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
self.logger.error(
f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
return pd.DataFrame(), 0
for matchExpr in match_expressions:
@ -280,7 +279,8 @@ class InfinityConnection(InfinityConnectionBase):
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
self.logger.warning(
f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
@ -288,7 +288,9 @@ class InfinityConnection(InfinityConnectionBase):
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd",
"question_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks",
"authors_sm_tks"]:
fields.add(field)
res_fields = self.get_fields(res, list(fields))
return res_fields.get(chunk_id, None)
@ -379,7 +381,9 @@ class InfinityConnection(InfinityConnectionBase):
d[k] = "_".join(f"{num:08x}" for num in v)
else:
d[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
"content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd",
"question_tks"]:
if k in d:
del d[k]
@ -478,7 +482,8 @@ class InfinityConnection(InfinityConnectionBase):
del new_value[k]
else:
new_value[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
"content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in new_value:
del new_value[k]
@ -502,7 +507,8 @@ class InfinityConnection(InfinityConnectionBase):
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
for update_kv, ids in remove_opt.items():
k, v = json.loads(update_kv)
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])),
{k: "###".join(v)})
table_instance.update(filter, new_value)
self.connPool.release_conn(inf_conn)
@ -561,7 +567,7 @@ class InfinityConnection(InfinityConnectionBase):
def to_position_int(v):
if v:
arr = [int(hex_val, 16) for hex_val in v.split("_")]
v = [arr[i : i + 5] for i in range(0, len(arr), 5)]
v = [arr[i: i + 5] for i in range(0, len(arr), 5)]
else:
v = []
return v
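# Round-trip sketch of the position encoding handled above: positions are stored as
# "_"-joined 8-digit hex values and unpacked back into rows of five ints by
# to_position_int(). The sample values are made up.
positions = [[0, 10, 20, 30, 40], [1, 15, 25, 35, 45]]

flat = [n for row in positions for n in row]
encoded = "_".join(f"{num:08x}" for num in flat)      # what insert()/update() store

arr = [int(h, 16) for h in encoded.split("_")]        # what to_position_int() undoes
decoded = [arr[i: i + 5] for i in range(0, len(arr), 5)]
assert decoded == positions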

View File

@ -46,6 +46,7 @@ class RAGFlowMinio:
# pass original identifier forward for use by other decorators
kwargs['_orig_bucket'] = original_bucket
return method(self, actual_bucket, *args, **kwargs)
return wrapper
@staticmethod
@ -71,6 +72,7 @@ class RAGFlowMinio:
fnm = f"{orig_bucket}/{fnm}"
return method(self, bucket, fnm, *args, **kwargs)
return wrapper
def __open__(self):

View File

@ -37,7 +37,8 @@ from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.float_utils import get_float
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
MatchDenseExpr
from rag.nlp import rag_tokenizer
ATTEMPT_TIME = 2
@ -719,19 +720,19 @@ class OBConnection(DocStoreConnection):
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None,
**kwargs,
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None,
**kwargs,
):
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
@ -1546,7 +1547,7 @@ class OBConnection(DocStoreConnection):
flags=re.IGNORECASE | re.MULTILINE,
)
if len(re.findall(r'</em><em>', highlighted_txt)) > 0 or len(
re.findall(r'</em>\s*<em>', highlighted_txt)) > 0:
re.findall(r'</em>\s*<em>', highlighted_txt)) > 0:
return highlighted_txt
else:
return None
@ -1565,9 +1566,9 @@ class OBConnection(DocStoreConnection):
if token_pos != -1:
if token in keywords:
highlighted_txt = (
highlighted_txt[:token_pos] +
f'<em>{token}</em>' +
highlighted_txt[token_pos + len(token):]
highlighted_txt[:token_pos] +
f'<em>{token}</em>' +
highlighted_txt[token_pos + len(token):]
)
last_pos = token_pos
return re.sub(r'</em><em>', '', highlighted_txt)
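# Small sketch of the <em> post-processing above: per-token wrapping can produce
# adjacent highlight tags, which are merged by dropping the "</em><em>" boundary.
# The input string is purely illustrative.
import re

highlighted = "deep <em>neural</em><em>network</em> search"
print(re.sub(r"</em><em>", "", highlighted))  # -> "deep <em>neuralnetwork</em> search"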

View File

@ -6,7 +6,6 @@ from urllib.parse import quote_plus
from common.config_utils import get_base_config
from common.decorator import singleton
CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS `{}` (
`key` VARCHAR(255) PRIMARY KEY,
@ -36,7 +35,8 @@ def get_opendal_config():
"table": opendal_config.get("config", {}).get("oss_table", "opendal_storage"),
"max_allowed_packet": str(max_packet)
}
kwargs["connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
kwargs[
"connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
else:
scheme = opendal_config.get("scheme")
config_data = opendal_config.get("config", {})
@ -61,7 +61,7 @@ def get_opendal_config():
del kwargs["password"]
if "connection_string" in kwargs:
del kwargs["connection_string"]
return kwargs
return kwargs
except Exception as e:
logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e))
raise
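# Sketch of how the MySQL connection string above is assembled; the credentials are
# placeholders. quote_plus() keeps special characters in the password from breaking
# the URL.
from urllib.parse import quote_plus

cfg = {"user": "ragflow", "password": "p@ss:word", "host": "127.0.0.1",
       "port": "3306", "database": "ragflow", "max_allowed_packet": "1073741824"}
conn_str = (
    f"mysql://{cfg['user']}:{quote_plus(cfg['password'])}@{cfg['host']}:{cfg['port']}"
    f"/{cfg['database']}?max_allowed_packet={cfg['max_allowed_packet']}"
)
print(conn_str)  # mysql://ragflow:p%40ss%3Aword@127.0.0.1:3306/ragflow?max_allowed_packet=1073741824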
@ -99,7 +99,6 @@ class OpenDALStorage:
def obj_exist(self, bucket, fnm, tenant_id=None):
return self._operator.exists(f"{bucket}/{fnm}")
def init_db_config(self):
try:
conn = pymysql.connect(

View File

@ -26,7 +26,8 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
@ -189,7 +190,7 @@ class OSConnection(DocStoreConnection):
minimum_should_match=minimum_should_match,
boost=1))
bqry.boost = 1.0 - vector_similarity_weight
# Elasticsearch's Python SDK provides a wrapper for KNN search,
# while the OpenSearch Python SDK does not,
# so the following code builds the KNN search for OpenSearch with the DSL.
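# A hedged sketch of what a hand-built k-NN query body can look like in OpenSearch;
# the vector field name "q_1024_vec", the embedding values, and k are assumptions,
# not taken from this file. The surrounding code later swaps such a body in as
# q["query"] = {"knn": knn_query}.
knn_query = {
    "q_1024_vec": {                      # vector field holding the stored embeddings (assumed)
        "vector": [0.12, -0.03, 0.44],   # query embedding, truncated for illustration
        "k": 10,                         # number of nearest neighbours to retrieve
    }
}
q = {"query": {"knn": knn_query}, "size": 10}
# `q` can then be passed as the body of opensearchpy's search(), like the DSL-built query here.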
@ -216,7 +217,7 @@ class OSConnection(DocStoreConnection):
if bqry:
s = s.query(bqry)
for field in highlightFields:
s = s.highlight(field,force_source=True,no_match_size=30,require_field_match=False)
s = s.highlight(field, force_source=True, no_match_size=30, require_field_match=False)
if orderBy:
orders = list()
@ -239,10 +240,10 @@ class OSConnection(DocStoreConnection):
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"OSConnection.search {str(indexNames)} query: " + json.dumps(q))
if use_knn:
del q["query"]
q["query"] = {"knn" : knn_query}
q["query"] = {"knn": knn_query}
for i in range(ATTEMPT_TIME):
try:
@ -328,7 +329,7 @@ class OSConnection(DocStoreConnection):
chunkId = condition["id"]
for i in range(ATTEMPT_TIME):
try:
self.os.update(index=indexName, id=chunkId, body={"doc":doc})
self.os.update(index=indexName, id=chunkId, body={"doc": doc})
return True
except Exception as e:
logger.exception(
@ -435,7 +436,7 @@ class OSConnection(DocStoreConnection):
logger.debug("OSConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
#print(Search().query(qry).to_dict(), flush=True)
# print(Search().query(qry).to_dict(), flush=True)
res = self.os.delete_by_query(
index=indexName,
body=Search().query(qry).to_dict(),

View File

@ -42,14 +42,16 @@ class RAGFlowOSS:
# If there is a default bucket, use the default bucket
actual_bucket = self.bucket if self.bucket else bucket
return method(self, actual_bucket, *args, **kwargs)
return wrapper
@staticmethod
def use_prefix_path(method):
def wrapper(self, bucket, fnm, *args, **kwargs):
# If the prefix path is set, use the prefix path
fnm = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm
return method(self, bucket, fnm, *args, **kwargs)
return wrapper
def __open__(self):
@ -171,4 +173,3 @@ class RAGFlowOSS:
self.__open__()
time.sleep(1)
return None
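# Stripped-down sketch of the decorator pattern shared by RAGFlowOSS, RAGFlowMinio,
# and RAGFlowS3 in this commit: method decorators rewrite the bucket and object key
# before the real storage call runs. The Storage class below is illustrative, not
# from this repository.
def use_default_bucket(method):
    def wrapper(self, bucket, *args, **kwargs):
        actual_bucket = self.bucket if self.bucket else bucket
        return method(self, actual_bucket, *args, **kwargs)
    return wrapper

def use_prefix_path(method):
    def wrapper(self, bucket, fnm, *args, **kwargs):
        fnm = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm
        return method(self, bucket, fnm, *args, **kwargs)
    return wrapper

class Storage:
    def __init__(self, bucket=None, prefix_path=None):
        self.bucket, self.prefix_path = bucket, prefix_path

    @use_default_bucket
    @use_prefix_path
    def put(self, bucket, fnm, data):
        print(f"PUT {bucket}/{fnm} ({len(data)} bytes)")

Storage(bucket="ragflow", prefix_path="tenant-1").put("ignored", "doc.pdf", b"...")
# -> PUT ragflow/tenant-1/doc.pdf (3 bytes)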

View File

@ -21,7 +21,6 @@ Utility functions for Raptor processing decisions.
import logging
from typing import Optional
# File extensions for structured data types
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
CSV_EXTENSIONS = {".csv", ".tsv"}
@ -40,12 +39,12 @@ def is_structured_file_type(file_type: Optional[str]) -> bool:
"""
if not file_type:
return False
# Normalize to lowercase and ensure leading dot
file_type = file_type.lower()
if not file_type.startswith("."):
file_type = f".{file_type}"
return file_type in STRUCTURED_EXTENSIONS
@ -61,23 +60,23 @@ def is_tabular_pdf(parser_id: str = "", parser_config: Optional[dict] = None) ->
True if PDF is being parsed as tabular data
"""
parser_config = parser_config or {}
# If using table parser, it's tabular
if parser_id and parser_id.lower() == "table":
return True
# Check if html4excel is enabled (Excel-like table parsing)
if parser_config.get("html4excel", False):
return True
return False
def should_skip_raptor(
file_type: Optional[str] = None,
parser_id: str = "",
parser_config: Optional[dict] = None,
raptor_config: Optional[dict] = None
file_type: Optional[str] = None,
parser_id: str = "",
parser_config: Optional[dict] = None,
raptor_config: Optional[dict] = None
) -> bool:
"""
Determine if Raptor should be skipped for a given document.
@ -97,30 +96,30 @@ def should_skip_raptor(
"""
parser_config = parser_config or {}
raptor_config = raptor_config or {}
# Check if auto-disable is explicitly disabled in config
if raptor_config.get("auto_disable_for_structured_data", True) is False:
logging.info("Raptor auto-disable is turned off via configuration")
return False
# Check for Excel/CSV files
if is_structured_file_type(file_type):
logging.info(f"Skipping Raptor for structured file type: {file_type}")
return True
# Check for tabular PDFs
if file_type and file_type.lower() in [".pdf", "pdf"]:
if is_tabular_pdf(parser_id, parser_config):
logging.info(f"Skipping Raptor for tabular PDF (parser_id={parser_id})")
return True
return False
def get_skip_reason(
file_type: Optional[str] = None,
parser_id: str = "",
parser_config: Optional[dict] = None
file_type: Optional[str] = None,
parser_id: str = "",
parser_config: Optional[dict] = None
) -> str:
"""
Get a human-readable reason why Raptor was skipped.
@ -134,12 +133,12 @@ def get_skip_reason(
Reason string, or empty string if Raptor should not be skipped
"""
parser_config = parser_config or {}
if is_structured_file_type(file_type):
return f"Structured data file ({file_type}) - Raptor auto-disabled"
if file_type and file_type.lower() in [".pdf", "pdf"]:
if is_tabular_pdf(parser_id, parser_config):
return f"Tabular PDF (parser={parser_id}) - Raptor auto-disabled"
return ""

View File

@ -33,6 +33,7 @@ except Exception:
except Exception:
REDIS = {}
class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer
@ -278,7 +279,8 @@ class RedisDB:
def decrby(self, key: str, decrement: int):
return self.REDIS.decrby(key, decrement)
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int:
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default",
increment: int = 1, ensure_minimum: int | None = None) -> int:
redis_key = f"{key_prefix}:{namespace}"
try:

View File

@ -46,6 +46,7 @@ class RAGFlowS3:
# If there is a default bucket, use the default bucket
actual_bucket = self.bucket if self.bucket else bucket
return method(self, actual_bucket, *args, **kwargs)
return wrapper
@staticmethod
@ -57,6 +58,7 @@ class RAGFlowS3:
if self.prefix_path:
fnm = f"{self.prefix_path}/{bucket}/{fnm}"
return method(self, bucket, fnm, *args, **kwargs)
return wrapper
def __open__(self):
@ -81,16 +83,16 @@ class RAGFlowS3:
s3_params['region_name'] = self.region_name
if self.endpoint_url:
s3_params['endpoint_url'] = self.endpoint_url
# Configure signature_version and addressing_style through Config object
if self.signature_version:
config_kwargs['signature_version'] = self.signature_version
if self.addressing_style:
config_kwargs['s3'] = {'addressing_style': self.addressing_style}
if config_kwargs:
s3_params['config'] = Config(**config_kwargs)
self.conn = [boto3.client('s3', **s3_params)]
except Exception:
logging.exception(f"Fail to connect at region {self.region_name} or endpoint {self.endpoint_url}")
@ -184,9 +186,9 @@ class RAGFlowS3:
for _ in range(10):
try:
r = self.conn[0].generate_presigned_url('get_object',
Params={'Bucket': bucket,
'Key': fnm},
ExpiresIn=expires)
Params={'Bucket': bucket,
'Key': fnm},
ExpiresIn=expires)
return r
except Exception:

View File

@ -30,7 +30,8 @@ class Tavily:
search_depth="advanced",
max_results=6
)
return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res in response["results"]]
return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res
in response["results"]]
except Exception as e:
logging.exception(e)
@ -64,5 +65,5 @@ class Tavily:
"count": 1,
"url": r["url"]
})
logging.info("[Tavily]R: "+r["content"][:128]+"...")
return {"chunks": chunks, "doc_aggs": aggs}
logging.info("[Tavily]R: " + r["content"][:128] + "...")
return {"chunks": chunks, "doc_aggs": aggs}