Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-29 16:05:35 +08:00)
Fix IDE warnings (#12281)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
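Almost every hunk below is a mechanical cleanup of the kind an IDE's PEP 8 inspection flags: keyword arguments lose the spaces around `=`, inline comments get two preceding spaces, operators get surrounding spaces, and over-long calls or import lists are wrapped. A minimal before/after sketch of that pattern is shown here; `parse_document` is a made-up placeholder, not a real ragflow function.

```python
# A minimal, hypothetical sketch of the style fixes this commit applies repo-wide.
# "parse_document" is a placeholder name, not a function from ragflow.

def parse_document(filename, binary=None, from_page=0, to_page=100000):
    # Stand-in for the real parsers; it just echoes its arguments.
    return [], [], (filename, binary, from_page, to_page)


# Before (the kind of line the IDE flags): spaces around "=" in keyword
# arguments, one space before the inline comment, one over-long call:
#   info = parse_document(filename = "a.pdf", binary = None, from_page = 0, to_page = 10) # parse

# After: PEP 8 keyword-argument spacing, two spaces before inline comments,
# and the call wrapped across lines.
sections, tables, info = parse_document(
    filename="a.pdf",
    binary=None,
    from_page=0,
    to_page=10,
)  # parse

print(info)
```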
@ -34,7 +34,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
if not ext:
raise RuntimeError("No extension detected.")

if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma",
".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
raise RuntimeError(f"Extension {ext} is not supported yet.")

tmp_path = ""

@ -22,7 +22,7 @@ from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer
from rag.nlp import bullets_category, is_english,remove_contents_table, \
from rag.nlp import bullets_category, is_english, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks, attach_media_context
from rag.nlp import rag_tokenizer
@ -91,9 +91,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
filename, binary=binary, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
tbls = vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_docx_wrapper(sections=sections, tbls=tbls, callback=callback, **kwargs)
# tbls = [((None, lns), None) for lns in tbls]
sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
sections = [(item[0], item[1] if item[1] is not None else "") for item in sections if
not isinstance(item[1], Image.Image)]
callback(0.8, "Finish parsing.")

elif re.search(r"\.pdf$", filename, re.IGNORECASE):
@ -109,14 +110,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")

sections, tables, pdf_parser = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
layout_recognizer = layout_recognizer,
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name,
**kwargs
)
@ -126,7 +127,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,

if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0

callback(0.8, "Finish parsing.")
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
@ -175,7 +176,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
for ck in hierarchical_merge(bull, sections, 5)]
else:
sections = [s.split("@") for s, _ in sections]
sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections ]
sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections]
chunks = naive_merge(
sections,
parser_config.get("chunk_token_num", 256),
@ -199,6 +200,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
import sys


def dummy(prog=None, msg=""):
pass


chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)

@ -26,13 +26,13 @@ import io

def chunk(
filename,
binary=None,
from_page=0,
to_page=100000,
lang="Chinese",
callback=None,
**kwargs,
filename,
binary=None,
from_page=0,
to_page=100000,
lang="Chinese",
callback=None,
**kwargs,
):
"""
Only eml is supported
@ -93,7 +93,8 @@ def chunk(
_add_content(msg, msg.get_content_type())

sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
(line, "") for line in
HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
]

st = timer()
@ -126,7 +127,9 @@ def chunk(
if __name__ == "__main__":
import sys


def dummy(prog=None, msg=""):
pass


chunk(sys.argv[1], callback=dummy)

@ -29,8 +29,6 @@ from rag.app.naive import by_plaintext, PARSERS
|
||||
from common.parser_config_utils import normalize_layout_recognizer
|
||||
|
||||
|
||||
|
||||
|
||||
class Docx(DocxParser):
|
||||
def __init__(self):
|
||||
pass
|
||||
@ -58,37 +56,36 @@ class Docx(DocxParser):
|
||||
return [line for line in lines if line]
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
|
||||
self.doc = Document(
|
||||
filename) if not binary else Document(BytesIO(binary))
|
||||
pn = 0
|
||||
lines = []
|
||||
level_set = set()
|
||||
bull = bullets_category([p.text for p in self.doc.paragraphs])
|
||||
for p in self.doc.paragraphs:
|
||||
if pn > to_page:
|
||||
break
|
||||
question_level, p_text = docx_question_level(p, bull)
|
||||
if not p_text.strip("\n"):
|
||||
self.doc = Document(
|
||||
filename) if not binary else Document(BytesIO(binary))
|
||||
pn = 0
|
||||
lines = []
|
||||
level_set = set()
|
||||
bull = bullets_category([p.text for p in self.doc.paragraphs])
|
||||
for p in self.doc.paragraphs:
|
||||
if pn > to_page:
|
||||
break
|
||||
question_level, p_text = docx_question_level(p, bull)
|
||||
if not p_text.strip("\n"):
|
||||
continue
|
||||
lines.append((question_level, p_text))
|
||||
level_set.add(question_level)
|
||||
for run in p.runs:
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
continue
|
||||
lines.append((question_level, p_text))
|
||||
level_set.add(question_level)
|
||||
for run in p.runs:
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
continue
|
||||
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
|
||||
pn += 1
|
||||
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
|
||||
pn += 1
|
||||
|
||||
sorted_levels = sorted(level_set)
|
||||
sorted_levels = sorted(level_set)
|
||||
|
||||
h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
|
||||
h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
|
||||
h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
|
||||
h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
|
||||
|
||||
root = Node(level=0, depth=h2_level, texts=[])
|
||||
root.build_tree(lines)
|
||||
|
||||
return [element for element in root.get_tree() if element]
|
||||
root = Node(level=0, depth=h2_level, texts=[])
|
||||
root.build_tree(lines)
|
||||
|
||||
return [element for element in root.get_tree() if element]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'''
|
||||
@ -121,8 +118,7 @@ class Pdf(PdfParser):
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin)
|
||||
callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
logging.debug("layouts:".format(
|
||||
))
|
||||
logging.debug("layouts: {}".format((timer() - start)))
|
||||
self._naive_vertical_merge()
|
||||
|
||||
callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
|
||||
@ -154,7 +150,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
chunks = Docx()(filename, binary)
|
||||
callback(0.7, "Finish parsing.")
|
||||
return tokenize_chunks(chunks, doc, eng, None)
|
||||
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
layout_recognizer, parser_model_name = normalize_layout_recognizer(
|
||||
parser_config.get("layout_recognize", "DeepDOC")
|
||||
@ -168,14 +164,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
raw_sections, tables, pdf_parser = parser(
|
||||
filename = filename,
|
||||
binary = binary,
|
||||
from_page = from_page,
|
||||
to_page = to_page,
|
||||
lang = lang,
|
||||
callback = callback,
|
||||
pdf_cls = Pdf,
|
||||
layout_recognizer = layout_recognizer,
|
||||
filename=filename,
|
||||
binary=binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
lang=lang,
|
||||
callback=callback,
|
||||
pdf_cls=Pdf,
|
||||
layout_recognizer=layout_recognizer,
|
||||
mineru_llm_name=parser_model_name,
|
||||
**kwargs
|
||||
)
|
||||
@ -185,7 +181,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
|
||||
if name in ["tcadp", "docling", "mineru"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
|
||||
for txt, poss in raw_sections:
|
||||
sections.append(txt + poss)
|
||||
|
||||
@ -226,7 +222,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
raise NotImplementedError(
|
||||
"file type not supported yet(doc, docx, pdf, txt supported)")
|
||||
|
||||
|
||||
# Remove 'Contents' part
|
||||
remove_contents_table(sections, eng)
|
||||
|
||||
@ -234,7 +229,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
bull = bullets_category(sections)
|
||||
res = tree_merge(bull, sections, 2)
|
||||
|
||||
|
||||
if not res:
|
||||
callback(0.99, "No chunk parsed out.")
|
||||
|
||||
@ -243,9 +237,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
# chunks = hierarchical_merge(bull, sections, 5)
|
||||
# return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
|
||||
@ -20,15 +20,17 @@ import re
|
||||
|
||||
from common.constants import ParserType
|
||||
from io import BytesIO
|
||||
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level, attach_media_context
|
||||
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, \
|
||||
docx_question_level, attach_media_context
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from deepdoc.parser import PdfParser, DocxParser
|
||||
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
|
||||
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper, vision_figure_parser_docx_wrapper
|
||||
from docx import Document
|
||||
from PIL import Image
|
||||
from rag.app.naive import by_plaintext, PARSERS
|
||||
from common.parser_config_utils import normalize_layout_recognizer
|
||||
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __init__(self):
|
||||
self.model_speciess = ParserType.MANUAL.value
|
||||
@ -129,11 +131,11 @@ class Docx(DocxParser):
|
||||
question_level, p_text = 0, ''
|
||||
if from_page <= pn < to_page and p.text.strip():
|
||||
question_level, p_text = docx_question_level(p)
|
||||
if not question_level or question_level > 6: # not a question
|
||||
if not question_level or question_level > 6: # not a question
|
||||
last_answer = f'{last_answer}\n{p_text}'
|
||||
current_image = self.get_picture(self.doc, p)
|
||||
last_image = self.concat_img(last_image, current_image)
|
||||
else: # is a question
|
||||
else: # is a question
|
||||
if last_answer or last_image:
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
@ -159,14 +161,14 @@ class Docx(DocxParser):
|
||||
|
||||
tbls = []
|
||||
for tb in self.doc.tables:
|
||||
html= "<table>"
|
||||
html = "<table>"
|
||||
for r in tb.rows:
|
||||
html += "<tr>"
|
||||
i = 0
|
||||
while i < len(r.cells):
|
||||
span = 1
|
||||
c = r.cells[i]
|
||||
for j in range(i+1, len(r.cells)):
|
||||
for j in range(i + 1, len(r.cells)):
|
||||
if c.text == r.cells[j].text:
|
||||
span += 1
|
||||
i = j
|
||||
@ -211,16 +213,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
kwargs.pop("parse_method", None)
|
||||
kwargs.pop("mineru_llm_name", None)
|
||||
sections, tbls, pdf_parser = pdf_parser(
|
||||
filename = filename,
|
||||
binary = binary,
|
||||
from_page = from_page,
|
||||
to_page = to_page,
|
||||
lang = lang,
|
||||
callback = callback,
|
||||
pdf_cls = Pdf,
|
||||
layout_recognizer = layout_recognizer,
|
||||
filename=filename,
|
||||
binary=binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
lang=lang,
|
||||
callback=callback,
|
||||
pdf_cls=Pdf,
|
||||
layout_recognizer=layout_recognizer,
|
||||
mineru_llm_name=parser_model_name,
|
||||
parse_method = "manual",
|
||||
parse_method="manual",
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@ -237,10 +239,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
if isinstance(poss, str):
|
||||
poss = pdf_parser.extract_positions(poss)
|
||||
if poss:
|
||||
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
|
||||
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
|
||||
pn = first[0]
|
||||
if isinstance(pn, list) and pn:
|
||||
pn = pn[0] # [pn] -> pn
|
||||
pn = pn[0] # [pn] -> pn
|
||||
poss[0] = (pn, *first[1:])
|
||||
|
||||
return (txt, layoutno, poss)
|
||||
@ -289,7 +291,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
if not rows:
|
||||
continue
|
||||
sections.append((rows if isinstance(rows, str) else rows[0], -1,
|
||||
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
|
||||
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
|
||||
|
||||
def tag(pn, left, right, top, bottom):
|
||||
if pn + left + right + top + bottom == 0:
|
||||
@ -312,7 +314,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
tk_cnt = num_tokens_from_string(txt)
|
||||
if sec_id > -1:
|
||||
last_sid = sec_id
|
||||
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
|
||||
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
|
||||
res = tokenize_table(tbls, doc, eng)
|
||||
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
|
||||
table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))
|
||||
@ -325,7 +327,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
docx_parser = Docx()
|
||||
ti_list, tbls = docx_parser(filename, binary,
|
||||
from_page=0, to_page=10000, callback=callback)
|
||||
tbls = vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
|
||||
tbls = vision_figure_parser_docx_wrapper(sections=ti_list, tbls=tbls, callback=callback, **kwargs)
|
||||
res = tokenize_table(tbls, doc, eng)
|
||||
for text, image in ti_list:
|
||||
d = copy.deepcopy(doc)
|
||||
|
||||
101  rag/app/naive.py
@ -31,16 +31,20 @@ from common.token_utils import num_tokens_from_string
|
||||
from common.constants import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html
|
||||
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
|
||||
from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
|
||||
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, \
|
||||
PdfParser, TxtParser
|
||||
from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_docx_wrapper, \
|
||||
vision_figure_parser_pdf_wrapper
|
||||
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
|
||||
from deepdoc.parser.docling_parser import DoclingParser
|
||||
from deepdoc.parser.tcadp_parser import TCADPParser
|
||||
from common.parser_config_utils import normalize_layout_recognizer
|
||||
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
|
||||
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, \
|
||||
tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
|
||||
|
||||
|
||||
def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
|
||||
def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
|
||||
**kwargs):
|
||||
callback = callback
|
||||
binary = binary
|
||||
pdf_parser = pdf_cls() if pdf_cls else Pdf()
|
||||
@ -58,17 +62,17 @@ def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
|
||||
|
||||
|
||||
def by_mineru(
|
||||
filename,
|
||||
binary=None,
|
||||
from_page=0,
|
||||
to_page=100000,
|
||||
lang="Chinese",
|
||||
callback=None,
|
||||
pdf_cls=None,
|
||||
parse_method: str = "raw",
|
||||
mineru_llm_name: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
**kwargs,
|
||||
filename,
|
||||
binary=None,
|
||||
from_page=0,
|
||||
to_page=100000,
|
||||
lang="Chinese",
|
||||
callback=None,
|
||||
pdf_cls=None,
|
||||
parse_method: str = "raw",
|
||||
mineru_llm_name: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
pdf_parser = None
|
||||
if tenant_id:
|
||||
@ -106,7 +110,8 @@ def by_mineru(
|
||||
return None, None, None
|
||||
|
||||
|
||||
def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
|
||||
def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
|
||||
**kwargs):
|
||||
pdf_parser = DoclingParser()
|
||||
parse_method = kwargs.get("parse_method", "raw")
|
||||
|
||||
@ -125,7 +130,7 @@ def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
|
||||
return sections, tables, pdf_parser
|
||||
|
||||
|
||||
def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
|
||||
def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None, **kwargs):
|
||||
tcadp_parser = TCADPParser()
|
||||
|
||||
if not tcadp_parser.check_installation():
|
||||
@ -168,10 +173,10 @@ def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=No
|
||||
|
||||
|
||||
PARSERS = {
|
||||
"deepdoc": by_deepdoc,
|
||||
"mineru": by_mineru,
|
||||
"docling": by_docling,
|
||||
"tcadp": by_tcadp,
|
||||
"deepdoc": by_deepdoc,
|
||||
"mineru": by_mineru,
|
||||
"docling": by_docling,
|
||||
"tcadp": by_tcadp,
|
||||
"plaintext": by_plaintext, # default
|
||||
}
|
||||
|
||||
@ -264,7 +269,7 @@ class Docx(DocxParser):
|
||||
|
||||
# Find the nearest heading paragraph in reverse order
|
||||
nearest_title = None
|
||||
for i in range(len(blocks)-1, -1, -1):
|
||||
for i in range(len(blocks) - 1, -1, -1):
|
||||
block_type, pos, block = blocks[i]
|
||||
if pos >= target_table_pos: # Skip blocks after the table
|
||||
continue
|
||||
@ -293,7 +298,7 @@ class Docx(DocxParser):
|
||||
# Find all parent headings, allowing cross-level search
|
||||
while current_level > 1:
|
||||
found = False
|
||||
for i in range(len(blocks)-1, -1, -1):
|
||||
for i in range(len(blocks) - 1, -1, -1):
|
||||
block_type, pos, block = blocks[i]
|
||||
if pos >= target_table_pos: # Skip blocks after the table
|
||||
continue
|
||||
@ -426,7 +431,8 @@ class Docx(DocxParser):
|
||||
|
||||
try:
|
||||
if inline_images:
|
||||
result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64))
|
||||
result = mammoth.convert_to_html(docx_file,
|
||||
convert_image=mammoth.images.img_element(_convert_image_to_base64))
|
||||
else:
|
||||
result = mammoth.convert_to_html(docx_file)
|
||||
|
||||
@ -621,6 +627,7 @@ class Markdown(MarkdownParser):
|
||||
return sections, tbls, section_images
|
||||
return sections, tbls
|
||||
|
||||
|
||||
def load_from_xml_v2(baseURI, rels_item_xml):
|
||||
"""
|
||||
Return |_SerializedRelationships| instance loaded with the
|
||||
@ -636,6 +643,7 @@ def load_from_xml_v2(baseURI, rels_item_xml):
|
||||
srels._srels.append(_SerializedRelationship(baseURI, rel_elm))
|
||||
return srels
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Supported file formats are docx, pdf, excel, txt.
|
||||
@ -651,7 +659,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True})
|
||||
|
||||
child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
|
||||
child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode(
|
||||
'latin1').decode('utf-8')
|
||||
cust_child_deli = re.findall(r"`([^`]+)`", child_deli)
|
||||
child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))
|
||||
if cust_child_deli:
|
||||
@ -685,7 +694,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
# Recursively chunk each embedded file and collect results
|
||||
for embed_filename, embed_bytes in embeds:
|
||||
try:
|
||||
sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False, **kwargs) or []
|
||||
sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False,
|
||||
**kwargs) or []
|
||||
embed_res.extend(sub_res)
|
||||
except Exception as e:
|
||||
if callback:
|
||||
@ -704,7 +714,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
|
||||
except Exception as e:
|
||||
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
|
||||
sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
|
||||
sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False,
|
||||
**kwargs)
|
||||
url_res.extend(sub_url_res)
|
||||
|
||||
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
|
||||
@ -747,14 +758,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
sections, tables, pdf_parser = parser(
|
||||
filename = filename,
|
||||
binary = binary,
|
||||
from_page = from_page,
|
||||
to_page = to_page,
|
||||
lang = lang,
|
||||
callback = callback,
|
||||
layout_recognizer = layout_recognizer,
|
||||
mineru_llm_name = parser_model_name,
|
||||
filename=filename,
|
||||
binary=binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
lang=lang,
|
||||
callback=callback,
|
||||
layout_recognizer=layout_recognizer,
|
||||
mineru_llm_name=parser_model_name,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@ -846,9 +857,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
else:
|
||||
section_images = [None] * len(sections)
|
||||
section_images[idx] = combined_image
|
||||
markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
|
||||
markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=[
|
||||
((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
|
||||
boosted_figures = markdown_vision_parser(callback=callback)
|
||||
sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]), sections[idx][1])
|
||||
sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]),
|
||||
sections[idx][1])
|
||||
|
||||
else:
|
||||
logging.warning("No visual model detected. Skipping figure parsing enhancement.")
|
||||
@ -945,7 +958,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
has_images = merged_images and any(img is not None for img in merged_images)
|
||||
|
||||
if has_images:
|
||||
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images, child_delimiters_pattern=child_deli))
|
||||
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images,
|
||||
child_delimiters_pattern=child_deli))
|
||||
else:
|
||||
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser, child_delimiters_pattern=child_deli))
|
||||
else:
|
||||
@ -955,10 +969,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
|
||||
if section_images:
|
||||
chunks, images = naive_merge_with_images(sections, section_images,
|
||||
int(parser_config.get(
|
||||
"chunk_token_num", 128)), parser_config.get(
|
||||
"delimiter", "\n!?。;!?"))
|
||||
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli))
|
||||
int(parser_config.get(
|
||||
"chunk_token_num", 128)), parser_config.get(
|
||||
"delimiter", "\n!?。;!?"))
|
||||
res.extend(
|
||||
tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli))
|
||||
else:
|
||||
chunks = naive_merge(
|
||||
sections, int(parser_config.get(
|
||||
@ -993,7 +1008,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
@ -26,6 +26,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
|
||||
from rag.app.naive import by_plaintext, PARSERS
|
||||
from common.parser_config_utils import normalize_layout_recognizer
|
||||
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
@ -95,14 +96,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
sections, tbls, pdf_parser = parser(
|
||||
filename = filename,
|
||||
binary = binary,
|
||||
from_page = from_page,
|
||||
to_page = to_page,
|
||||
lang = lang,
|
||||
callback = callback,
|
||||
pdf_cls = Pdf,
|
||||
layout_recognizer = layout_recognizer,
|
||||
filename=filename,
|
||||
binary=binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
lang=lang,
|
||||
callback=callback,
|
||||
pdf_cls=Pdf,
|
||||
layout_recognizer=layout_recognizer,
|
||||
mineru_llm_name=parser_model_name,
|
||||
**kwargs
|
||||
)
|
||||
@ -112,9 +113,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
|
||||
if name in ["tcadp", "docling", "mineru"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
|
||||
for (img, rows), poss in tbls:
|
||||
if not rows:
|
||||
continue
|
||||
@ -172,7 +173,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
@ -20,7 +20,8 @@ import re
|
||||
|
||||
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
|
||||
from common.constants import ParserType
|
||||
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks, attach_media_context
|
||||
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, \
|
||||
tokenize_chunks, attach_media_context
|
||||
from deepdoc.parser import PdfParser
|
||||
import numpy as np
|
||||
from rag.app.naive import by_plaintext, PARSERS
|
||||
@ -66,7 +67,7 @@ class Pdf(PdfParser):
|
||||
# clean mess
|
||||
if column_width < self.page_images[0].size[0] / zoomin / 2:
|
||||
logging.debug("two_column................... {} {}".format(column_width,
|
||||
self.page_images[0].size[0] / zoomin / 2))
|
||||
self.page_images[0].size[0] / zoomin / 2))
|
||||
self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
|
||||
for b in self.boxes:
|
||||
b["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", b["text"].strip())
|
||||
@ -89,7 +90,7 @@ class Pdf(PdfParser):
|
||||
title = ""
|
||||
authors = []
|
||||
i = 0
|
||||
while i < min(32, len(self.boxes)-1):
|
||||
while i < min(32, len(self.boxes) - 1):
|
||||
b = self.boxes[i]
|
||||
i += 1
|
||||
if b.get("layoutno", "").find("title") >= 0:
|
||||
@ -190,8 +191,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
"tables": tables
|
||||
}
|
||||
|
||||
tbls=paper["tables"]
|
||||
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
|
||||
tbls = paper["tables"]
|
||||
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
|
||||
paper["tables"] = tbls
|
||||
else:
|
||||
raise NotImplementedError("file type not supported yet(pdf supported)")
|
||||
@ -329,6 +330,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
|
||||
@ -51,7 +51,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
|
||||
}
|
||||
)
|
||||
cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
|
||||
ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
|
||||
ans = asyncio.run(
|
||||
cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
|
||||
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
|
||||
ans += "\n" + ans
|
||||
tokenize(doc, ans, eng)
|
||||
|
||||
@ -249,7 +249,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(a, b):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
|
||||
@ -102,9 +102,9 @@ class Pdf(PdfParser):
|
||||
self._text_merge()
|
||||
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
|
||||
tbls = self._extract_table_figure(True, zoomin, True, True)
|
||||
#self._naive_vertical_merge()
|
||||
# self._naive_vertical_merge()
|
||||
# self._concat_downward()
|
||||
#self._filter_forpages()
|
||||
# self._filter_forpages()
|
||||
logging.debug("layouts: {}".format(timer() - start))
|
||||
sections = [b["text"] for b in self.boxes]
|
||||
bull_x0_list = []
|
||||
@ -114,12 +114,14 @@ class Pdf(PdfParser):
|
||||
qai_list = []
|
||||
last_q, last_a, last_tag = '', '', ''
|
||||
last_index = -1
|
||||
last_box = {'text':''}
|
||||
last_box = {'text': ''}
|
||||
last_bull = None
|
||||
|
||||
def sort_key(element):
|
||||
tbls_pn = element[1][0][0]
|
||||
tbls_top = element[1][0][3]
|
||||
return tbls_pn, tbls_top
|
||||
|
||||
tbls.sort(key=sort_key)
|
||||
tbl_index = 0
|
||||
last_pn, last_bottom = 0, 0
|
||||
@ -133,28 +135,32 @@ class Pdf(PdfParser):
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
|
||||
if not has_bull: # No question bullet
|
||||
if not last_q:
|
||||
if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
|
||||
if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
|
||||
tbl_index += 1
|
||||
continue
|
||||
else:
|
||||
sum_tag = line_tag
|
||||
sum_section = section
|
||||
while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
|
||||
and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the middle of current answer
|
||||
while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
|
||||
and ((tbl_pn == line_pn and tbl_top <= line_top) or (
|
||||
tbl_pn < line_pn)): # add image at the middle of current answer
|
||||
sum_tag = f'{tbl_tag}{sum_tag}'
|
||||
sum_section = f'{tbl_text}{sum_section}'
|
||||
tbl_index += 1
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
|
||||
tbl_index)
|
||||
last_a = f'{last_a}{sum_section}'
|
||||
last_tag = f'{last_tag}{sum_tag}'
|
||||
else:
|
||||
if last_q:
|
||||
while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
|
||||
and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the end of last answer
|
||||
while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
|
||||
and ((tbl_pn == line_pn and tbl_top <= line_top) or (
|
||||
tbl_pn < line_pn)): # add image at the end of last answer
|
||||
last_tag = f'{last_tag}{tbl_tag}'
|
||||
last_a = f'{last_a}{tbl_text}'
|
||||
tbl_index += 1
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
|
||||
tbl_index)
|
||||
image, poss = self.crop(last_tag, need_position=True)
|
||||
qai_list.append((last_q, last_a, image, poss))
|
||||
last_q, last_a, last_tag = '', '', ''
|
||||
@ -171,7 +177,7 @@ class Pdf(PdfParser):
|
||||
def get_tbls_info(self, tbls, tbl_index):
|
||||
if tbl_index >= len(tbls):
|
||||
return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
|
||||
tbl_pn = tbls[tbl_index][1][0][0]+1
|
||||
tbl_pn = tbls[tbl_index][1][0][0] + 1
|
||||
tbl_left = tbls[tbl_index][1][0][1]
|
||||
tbl_right = tbls[tbl_index][1][0][2]
|
||||
tbl_top = tbls[tbl_index][1][0][3]
|
||||
@ -210,11 +216,11 @@ class Docx(DocxParser):
|
||||
question_level, p_text = 0, ''
|
||||
if from_page <= pn < to_page and p.text.strip():
|
||||
question_level, p_text = docx_question_level(p)
|
||||
if not question_level or question_level > 6: # not a question
|
||||
if not question_level or question_level > 6: # not a question
|
||||
last_answer = f'{last_answer}\n{p_text}'
|
||||
current_image = self.get_picture(self.doc, p)
|
||||
last_image = concat_img(last_image, current_image)
|
||||
else: # is a question
|
||||
else: # is a question
|
||||
if last_answer or last_image:
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
@ -240,14 +246,14 @@ class Docx(DocxParser):
|
||||
|
||||
tbls = []
|
||||
for tb in self.doc.tables:
|
||||
html= "<table>"
|
||||
html = "<table>"
|
||||
for r in tb.rows:
|
||||
html += "<tr>"
|
||||
i = 0
|
||||
while i < len(r.cells):
|
||||
span = 1
|
||||
c = r.cells[i]
|
||||
for j in range(i+1, len(r.cells)):
|
||||
for j in range(i + 1, len(r.cells)):
|
||||
if c.text == r.cells[j].text:
|
||||
span += 1
|
||||
i = j
|
||||
@ -356,7 +362,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
if question:
|
||||
answer += "\n" + lines[i]
|
||||
else:
|
||||
fails.append(str(i+1))
|
||||
fails.append(str(i + 1))
|
||||
elif len(arr) == 2:
|
||||
if question and answer:
|
||||
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
|
||||
@ -429,13 +435,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
if not code_block:
|
||||
question_level, question = mdQuestionLevel(line)
|
||||
|
||||
if not question_level or question_level > 6: # not a question
|
||||
if not question_level or question_level > 6: # not a question
|
||||
last_answer = f'{last_answer}\n{line}'
|
||||
else: # is a question
|
||||
else: # is a question
|
||||
if last_answer.strip():
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
|
||||
res.append(beAdoc(deepcopy(doc), sum_question,
|
||||
markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
|
||||
last_answer = ''
|
||||
|
||||
i = question_level
|
||||
@ -447,13 +454,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
if last_answer.strip():
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
|
||||
res.append(beAdoc(deepcopy(doc), sum_question,
|
||||
markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
|
||||
return res
|
||||
|
||||
elif re.search(r"\.docx$", filename, re.IGNORECASE):
|
||||
docx_parser = Docx()
|
||||
qai_list, tbls = docx_parser(filename, binary,
|
||||
from_page=0, to_page=10000, callback=callback)
|
||||
from_page=0, to_page=10000, callback=callback)
|
||||
res = tokenize_table(tbls, doc, eng)
|
||||
for i, (q, a, image) in enumerate(qai_list):
|
||||
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
|
||||
@ -466,6 +474,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
@ -64,7 +64,8 @@ def remote_call(filename, binary):
|
||||
del resume[k]
|
||||
|
||||
resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
|
||||
"updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
|
||||
"updated_at": datetime.datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S")}]))
|
||||
resume = step_two.parse(resume)
|
||||
return resume
|
||||
except Exception:
|
||||
@ -171,6 +172,9 @@ def chunk(filename, binary=None, callback=None, **kwargs):
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(a, b):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
|
||||
@ -51,14 +51,15 @@ class Excel(ExcelParser):
|
||||
tables = []
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
images = Excel._extract_images_from_worksheet(ws,sheetname=sheetname)
|
||||
images = Excel._extract_images_from_worksheet(ws, sheetname=sheetname)
|
||||
if images:
|
||||
image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback, **kwargs)
|
||||
image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback,
|
||||
**kwargs)
|
||||
if image_descriptions and len(image_descriptions) == len(images):
|
||||
for i, bf in enumerate(image_descriptions):
|
||||
images[i]["image_description"] = "\n".join(bf[0][1])
|
||||
for img in images:
|
||||
if (img["span_type"] == "single_cell"and img.get("image_description")):
|
||||
if (img["span_type"] == "single_cell" and img.get("image_description")):
|
||||
pending_cell_images.append(img)
|
||||
else:
|
||||
flow_images.append(img)
|
||||
@ -113,16 +114,17 @@ class Excel(ExcelParser):
|
||||
tables.append(
|
||||
(
|
||||
(
|
||||
img["image"], # Image.Image
|
||||
[img["image_description"]] # description list (must be list)
|
||||
img["image"], # Image.Image
|
||||
[img["image_description"]] # description list (must be list)
|
||||
),
|
||||
[
|
||||
(0, 0, 0, 0, 0) # dummy position
|
||||
(0, 0, 0, 0, 0) # dummy position
|
||||
]
|
||||
)
|
||||
)
|
||||
callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
return res,tables
|
||||
callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
return res, tables
|
||||
|
||||
def _parse_headers(self, ws, rows):
|
||||
if len(rows) == 0:
|
||||
@ -315,14 +317,15 @@ def trans_bool(s):
|
||||
def column_data_type(arr):
|
||||
arr = list(arr)
|
||||
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
|
||||
trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
|
||||
trans = {t: f for f, t in
|
||||
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
|
||||
float_flag = False
|
||||
for a in arr:
|
||||
if a is None:
|
||||
continue
|
||||
if re.match(r"[+-]?[0-9]+$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"):
|
||||
counts["int"] += 1
|
||||
if int(str(a)) > 2**63 - 1:
|
||||
if int(str(a)) > 2 ** 63 - 1:
|
||||
float_flag = True
|
||||
break
|
||||
elif re.match(r"[+-]?[0-9.]{,19}$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"):
|
||||
@ -370,7 +373,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
||||
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
excel_parser = Excel()
|
||||
dfs,tbls = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback, **kwargs)
|
||||
dfs, tbls = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback, **kwargs)
|
||||
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
@ -389,7 +392,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
||||
continue
|
||||
rows.append(row)
|
||||
|
||||
callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
|
||||
elif re.search(r"\.csv$", filename, re.IGNORECASE):
|
||||
@ -406,7 +410,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
||||
fails = []
|
||||
rows = []
|
||||
|
||||
for i, row in enumerate(all_rows[1 + from_page : 1 + to_page]):
|
||||
for i, row in enumerate(all_rows[1 + from_page: 1 + to_page]):
|
||||
if len(row) != len(headers):
|
||||
fails.append(str(i + from_page))
|
||||
continue
|
||||
@ -415,7 +419,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
||||
callback(
|
||||
0.3,
|
||||
(f"Extract records: {from_page}~{from_page + len(rows)}" +
|
||||
(f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else ""))
|
||||
(f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else ""))
|
||||
)
|
||||
|
||||
dfs = [pd.DataFrame(rows, columns=headers)]
|
||||
@ -445,7 +449,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
||||
df[clmns[j]] = cln
|
||||
if ty == "text":
|
||||
txts.extend([str(c) for c in cln if c])
|
||||
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))]
|
||||
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
|
||||
range(len(clmns))]
|
||||
|
||||
eng = lang.lower() == "english" # is_english(txts)
|
||||
for ii, row in df.iterrows():
|
||||
@ -477,7 +482,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
|
||||
@ -141,17 +141,20 @@ def label_question(question, kbs):
|
||||
if not tag_kbs:
|
||||
return tags
|
||||
tags = settings.retriever.tag_query(question,
|
||||
list(set([kb.tenant_id for kb in tag_kbs])),
|
||||
tag_kb_ids,
|
||||
all_tags,
|
||||
kb.parser_config.get("topn_tags", 3)
|
||||
)
|
||||
list(set([kb.tenant_id for kb in tag_kbs])),
|
||||
tag_kb_ids,
|
||||
all_tags,
|
||||
kb.parser_config.get("topn_tags", 3)
|
||||
)
|
||||
return tags
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
@ -263,7 +263,7 @@ class SparkTTS(Base):
|
||||
raise Exception(error)
|
||||
|
||||
def on_close(self, ws, close_status_code, close_msg):
|
||||
self.audio_queue.put(None) # 放入 None 作为结束标志
|
||||
self.audio_queue.put(None) # None is terminator
|
||||
|
||||
def on_open(self, ws):
|
||||
def run(*args):
|
||||
|
||||
@ -273,7 +273,7 @@ def tokenize(d, txt, eng):
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
|
||||
|
||||
def split_with_pattern(d, pattern:str, content:str, eng) -> list:
|
||||
def split_with_pattern(d, pattern: str, content: str, eng) -> list:
|
||||
docs = []
|
||||
txts = [txt for txt in re.split(r"(%s)" % pattern, content, flags=re.DOTALL)]
|
||||
for j in range(0, len(txts), 2):
|
||||
@ -281,7 +281,7 @@ def split_with_pattern(d, pattern:str, content:str, eng) -> list:
|
||||
if not txt:
|
||||
continue
|
||||
if j + 1 < len(txts):
|
||||
txt += txts[j+1]
|
||||
txt += txts[j + 1]
|
||||
dd = copy.deepcopy(d)
|
||||
tokenize(dd, txt, eng)
|
||||
docs.append(dd)
|
||||
@ -304,7 +304,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None, child_delimiters_pattern=
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
add_positions(d, [[ii]*5])
|
||||
add_positions(d, [[ii] * 5])
|
||||
|
||||
if child_delimiters_pattern:
|
||||
d["mom_with_weight"] = ck
|
||||
@ -325,7 +325,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images, child_delimiters_patte
|
||||
logging.debug("-- {}".format(ck))
|
||||
d = copy.deepcopy(doc)
|
||||
d["image"] = image
|
||||
add_positions(d, [[ii]*5])
|
||||
add_positions(d, [[ii] * 5])
|
||||
if child_delimiters_pattern:
|
||||
d["mom_with_weight"] = ck
|
||||
res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng))
|
||||
@ -658,7 +658,8 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
|
||||
if "content_ltks" in ck:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(combined)
|
||||
if "content_sm_ltks" in ck:
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
|
||||
ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
|
||||
|
||||
if positioned_indices:
|
||||
chunks[:] = [chunks[i] for i in ordered_indices]
|
||||
@ -764,8 +765,8 @@ def not_title(txt):
|
||||
return True
|
||||
return re.search(r"[,;,。;!!]", txt)
|
||||
|
||||
def tree_merge(bull, sections, depth):
|
||||
|
||||
def tree_merge(bull, sections, depth):
|
||||
if not sections or bull < 0:
|
||||
return sections
|
||||
if isinstance(sections[0], type("")):
|
||||
@ -777,16 +778,17 @@ def tree_merge(bull, sections, depth):
|
||||
|
||||
def get_level(bull, section):
|
||||
text, layout = section
|
||||
text = re.sub(r"\u3000", " ", text).strip()
|
||||
text = re.sub(r"\u3000", " ", text).strip()
|
||||
|
||||
for i, title in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(title, text.strip()):
|
||||
return i+1, text
|
||||
return i + 1, text
|
||||
else:
|
||||
if re.search(r"(title|head)", layout) and not not_title(text):
|
||||
return len(BULLET_PATTERN[bull])+1, text
|
||||
return len(BULLET_PATTERN[bull]) + 1, text
|
||||
else:
|
||||
return len(BULLET_PATTERN[bull])+2, text
|
||||
return len(BULLET_PATTERN[bull]) + 2, text
|
||||
|
||||
level_set = set()
|
||||
lines = []
|
||||
for section in sections:
|
||||
@ -812,8 +814,8 @@ def tree_merge(bull, sections, depth):
|
||||
|
||||
return [element for element in root.get_tree() if element]
|
||||
|
||||
def hierarchical_merge(bull, sections, depth):
|
||||
|
||||
def hierarchical_merge(bull, sections, depth):
|
||||
if not sections or bull < 0:
|
||||
return []
|
||||
if isinstance(sections[0], type("")):
|
||||
@ -922,10 +924,10 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.:
|
||||
if cks:
|
||||
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
|
||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||
t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
@ -957,7 +959,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
|
||||
return cks
|
||||
|
||||
for sec, pos in sections:
|
||||
add_chunk("\n"+sec, pos)
|
||||
add_chunk("\n" + sec, pos)
|
||||
|
||||
return cks
|
||||
|
||||
@ -978,10 +980,10 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.:
|
||||
if cks:
|
||||
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
|
||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||
t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
@ -1025,9 +1027,9 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
|
||||
if isinstance(text, tuple):
|
||||
text_str = text[0]
|
||||
text_pos = text[1] if len(text) > 1 else ""
|
||||
add_chunk("\n"+text_str, image, text_pos)
|
||||
add_chunk("\n" + text_str, image, text_pos)
|
||||
else:
|
||||
add_chunk("\n"+text, image)
|
||||
add_chunk("\n" + text, image)
|
||||
|
||||
return cks, result_images
|
||||
|
||||
@ -1042,7 +1044,7 @@ def docx_question_level(p, bull=-1):
|
||||
for j, title in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(title, txt):
|
||||
return j + 1, txt
|
||||
return len(BULLET_PATTERN[bull])+1, txt
|
||||
return len(BULLET_PATTERN[bull]) + 1, txt
|
||||
|
||||
|
||||
def concat_img(img1, img2):
|
||||
@ -1211,7 +1213,7 @@ class Node:
|
||||
child = node.get_children()
|
||||
|
||||
if level == 0 and texts:
|
||||
tree_list.append("\n".join(titles+texts))
|
||||
tree_list.append("\n".join(titles + texts))
|
||||
|
||||
# Titles within configured depth are accumulated into the current path
|
||||
if 1 <= level <= self.depth:
|
||||
|
||||
@ -205,11 +205,11 @@ class FulltextQueryer(QueryBase):
|
||||
s = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
if k in dtwt:
|
||||
s += v #* dtwt[k]
|
||||
s += v # * dtwt[k]
|
||||
q = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
q += v #* v
|
||||
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
|
||||
q += v # * v
|
||||
return s / q # math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
|
||||
|
||||
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
|
||||
if isinstance(content_tks, str):
|
||||
@ -232,4 +232,5 @@ class FulltextQueryer(QueryBase):
|
||||
keywords.append(f"{tk}^{w}")
|
||||
|
||||
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
|
||||
{"minimum_should_match": min(3, len(keywords) / 10), "original_query": " ".join(origin_keywords)})
|
||||
{"minimum_should_match": min(3, len(keywords) / 10),
|
||||
"original_query": " ".join(origin_keywords)})
|
||||
|
||||
@ -66,7 +66,8 @@ class Dealer:
|
||||
if key in req and req[key] is not None:
|
||||
condition[field] = req[key]
|
||||
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
||||
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
|
||||
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
|
||||
"removed_kwd"]:
|
||||
if key in req and req[key] is not None:
|
||||
condition[key] = req[key]
|
||||
return condition
|
||||
@ -141,7 +142,8 @@ class Dealer:
|
||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||
matchDense.extra_options["similarity"] = 0.17
|
||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
||||
orderBy, offset, limit, idx_names, kb_ids,
|
||||
rank_feature=rank_feature)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
||||
|
||||
@ -218,8 +220,9 @@ class Dealer:
|
||||
ans_v, _ = embd_mdl.encode(pieces_)
|
||||
for i in range(len(chunk_v)):
|
||||
if len(ans_v[0]) != len(chunk_v[i]):
|
||||
chunk_v[i] = [0.0]*len(ans_v[0])
|
||||
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
|
||||
chunk_v[i] = [0.0] * len(ans_v[0])
|
||||
logging.warning(
|
||||
"The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
|
||||
|
||||
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
||||
len(ans_v[0]), len(chunk_v[0]))
|
||||
@ -273,7 +276,7 @@ class Dealer:
|
||||
if not query_rfea:
|
||||
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
||||
|
||||
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||
for i in search_res.ids:
|
||||
nor, denor = 0, 0
|
||||
if not search_res.field[i].get(TAG_FLD):
|
||||
@ -286,8 +289,8 @@ class Dealer:
|
||||
if denor == 0:
|
||||
rank_fea.append(0)
|
||||
else:
|
||||
rank_fea.append(nor/np.sqrt(denor)/q_denor)
|
||||
return np.array(rank_fea)*10. + pageranks
|
||||
rank_fea.append(nor / np.sqrt(denor) / q_denor)
|
||||
return np.array(rank_fea) * 10. + pageranks
|
||||
|
||||
def rerank(self, sres, query, tkweight=0.3,
|
||||
vtweight=0.7, cfield="content_ltks",
|
||||
@@ -358,21 +361,21 @@ class Dealer:
                                rag_tokenizer.tokenize(inst).split())

    def retrieval(
            self,
            question,
            embd_mdl,
            tenant_ids,
            kb_ids,
            page,
            page_size,
            similarity_threshold=0.2,
            vector_similarity_weight=0.3,
            top=1024,
            doc_ids=None,
            aggs=True,
            rerank_mdl=None,
            highlight=False,
            rank_feature: dict | None = {PAGERANK_FLD: 10},
    ):
        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
        if not question:

@@ -395,7 +398,8 @@ class Dealer:
        if isinstance(tenant_ids, str):
            tenant_ids = tenant_ids.split(",")

-       sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
+       sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
+                          rank_feature=rank_feature)

        if rerank_mdl and sres.total > 0:
            sim, tsim, vsim = self.rerank_by_model(
@@ -558,13 +562,14 @@ class Dealer:

    def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
        idx_nm = index_name(tenant_id)
-       match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
+       match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
+                                       keywords_topn)
        res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
        aggs = self.dataStore.get_aggregation(res, "tag_kwd")
        if not aggs:
            return False
        cnt = np.sum([c for _, c in aggs])
-       tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
+       tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
                         key=lambda x: x[1] * -1)[:topn_tags]
        doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
        return True
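The expression above scores a tag by its smoothed share within the document divided by its corpus-wide share from all_tags, so tags that are frequent in this document but rare overall come out on top. A small self-contained sketch with illustrative inputs; the real code reads the per-document counts from a doc-engine aggregation rather than a plain dict.

def tag_scores(doc_tag_counts: dict, all_tags: dict, topn_tags=3, S=1000):
    # Smoothed in-document frequency over corpus-level share, as in tag_content above.
    # doc_tag_counts and all_tags are illustrative stand-ins for the aggregation results.
    cnt = sum(doc_tag_counts.values())
    scored = [(tag, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(tag, 0.0001))))
              for tag, c in doc_tag_counts.items()]
    return sorted(scored, key=lambda x: x[1] * -1)[:topn_tags]

# e.g. tag_scores({"nlp": 12, "search": 3}, {"nlp": 0.002, "search": 0.01})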
@@ -580,11 +585,11 @@ class Dealer:
        if not aggs:
            return {}
        cnt = np.sum([c for _, c in aggs])
-       tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
+       tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
                         key=lambda x: x[1] * -1)[:topn_tags]
        return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}

-   def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6):
+   def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
        if not chunks:
            return []
        idx_nms = [index_name(tid) for tid in tenant_ids]

@@ -594,9 +599,10 @@ class Dealer:
                ranks[ck["doc_id"]] = 0
            ranks[ck["doc_id"]] += ck["similarity"]
            doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
-       doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0]
+       doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
        kb_ids = [doc_id2kb_id[doc_id]]
-       es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
+       es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
+                                      OrderByExpr(), 0, 128, idx_nms,
                                       kb_ids)
        toc = []
        dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])

@@ -608,7 +614,7 @@ class Dealer:
        if not toc:
            return chunks

-       ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn*2))
+       ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2))
        if not ids:
            return chunks

@@ -644,9 +650,9 @@ class Dealer:
                    break
            chunks.append(d)

-       return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]
+       return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]

-   def retrieval_by_children(self, chunks:list[dict], tenant_ids:list[str]):
+   def retrieval_by_children(self, chunks: list[dict], tenant_ids: list[str]):
        if not chunks:
            return []
        idx_nms = [index_name(tid) for tid in tenant_ids]

@@ -692,4 +698,4 @@ class Dealer:
                    break
            chunks.append(d)

-       return sorted(chunks, key=lambda x:x["similarity"]*-1)
+       return sorted(chunks, key=lambda x: x["similarity"] * -1)

@@ -14,129 +14,131 @@
# limitations under the License.
#

m = set(["赵","钱","孙","李",
|
||||
"周","吴","郑","王",
|
||||
"冯","陈","褚","卫",
|
||||
"蒋","沈","韩","杨",
|
||||
"朱","秦","尤","许",
|
||||
"何","吕","施","张",
|
||||
"孔","曹","严","华",
|
||||
"金","魏","陶","姜",
|
||||
"戚","谢","邹","喻",
|
||||
"柏","水","窦","章",
|
||||
"云","苏","潘","葛",
|
||||
"奚","范","彭","郎",
|
||||
"鲁","韦","昌","马",
|
||||
"苗","凤","花","方",
|
||||
"俞","任","袁","柳",
|
||||
"酆","鲍","史","唐",
|
||||
"费","廉","岑","薛",
|
||||
"雷","贺","倪","汤",
|
||||
"滕","殷","罗","毕",
|
||||
"郝","邬","安","常",
|
||||
"乐","于","时","傅",
|
||||
"皮","卞","齐","康",
|
||||
"伍","余","元","卜",
|
||||
"顾","孟","平","黄",
|
||||
"和","穆","萧","尹",
|
||||
"姚","邵","湛","汪",
|
||||
"祁","毛","禹","狄",
|
||||
"米","贝","明","臧",
|
||||
"计","伏","成","戴",
|
||||
"谈","宋","茅","庞",
|
||||
"熊","纪","舒","屈",
|
||||
"项","祝","董","梁",
|
||||
"杜","阮","蓝","闵",
|
||||
"席","季","麻","强",
|
||||
"贾","路","娄","危",
|
||||
"江","童","颜","郭",
|
||||
"梅","盛","林","刁",
|
||||
"钟","徐","邱","骆",
|
||||
"高","夏","蔡","田",
|
||||
"樊","胡","凌","霍",
|
||||
"虞","万","支","柯",
|
||||
"昝","管","卢","莫",
|
||||
"经","房","裘","缪",
|
||||
"干","解","应","宗",
|
||||
"丁","宣","贲","邓",
|
||||
"郁","单","杭","洪",
|
||||
"包","诸","左","石",
|
||||
"崔","吉","钮","龚",
|
||||
"程","嵇","邢","滑",
|
||||
"裴","陆","荣","翁",
|
||||
"荀","羊","於","惠",
|
||||
"甄","曲","家","封",
|
||||
"芮","羿","储","靳",
|
||||
"汲","邴","糜","松",
|
||||
"井","段","富","巫",
|
||||
"乌","焦","巴","弓",
|
||||
"牧","隗","山","谷",
|
||||
"车","侯","宓","蓬",
|
||||
"全","郗","班","仰",
|
||||
"秋","仲","伊","宫",
|
||||
"宁","仇","栾","暴",
|
||||
"甘","钭","厉","戎",
|
||||
"祖","武","符","刘",
|
||||
"景","詹","束","龙",
|
||||
"叶","幸","司","韶",
|
||||
"郜","黎","蓟","薄",
|
||||
"印","宿","白","怀",
|
||||
"蒲","邰","从","鄂",
|
||||
"索","咸","籍","赖",
|
||||
"卓","蔺","屠","蒙",
|
||||
"池","乔","阴","鬱",
|
||||
"胥","能","苍","双",
|
||||
"闻","莘","党","翟",
|
||||
"谭","贡","劳","逄",
|
||||
"姬","申","扶","堵",
|
||||
"冉","宰","郦","雍",
|
||||
"郤","璩","桑","桂",
|
||||
"濮","牛","寿","通",
|
||||
"边","扈","燕","冀",
|
||||
"郏","浦","尚","农",
|
||||
"温","别","庄","晏",
|
||||
"柴","瞿","阎","充",
|
||||
"慕","连","茹","习",
|
||||
"宦","艾","鱼","容",
|
||||
"向","古","易","慎",
|
||||
"戈","廖","庾","终",
|
||||
"暨","居","衡","步",
|
||||
"都","耿","满","弘",
|
||||
"匡","国","文","寇",
|
||||
"广","禄","阙","东",
|
||||
"欧","殳","沃","利",
|
||||
"蔚","越","夔","隆",
|
||||
"师","巩","厍","聂",
|
||||
"晁","勾","敖","融",
|
||||
"冷","訾","辛","阚",
|
||||
"那","简","饶","空",
|
||||
"曾","母","沙","乜",
|
||||
"养","鞠","须","丰",
|
||||
"巢","关","蒯","相",
|
||||
"查","后","荆","红",
|
||||
"游","竺","权","逯",
|
||||
"盖","益","桓","公",
|
||||
"兰","原","乞","西","阿","肖","丑","位","曽","巨","德","代","圆","尉","仵","纳","仝","脱","丘","但","展","迪","付","覃","晗","特","隋","苑","奥","漆","谌","郄","练","扎","邝","渠","信","门","陳","化","原","密","泮","鹿","赫",
|
||||
"万俟","司马","上官","欧阳",
|
||||
"夏侯","诸葛","闻人","东方",
|
||||
"赫连","皇甫","尉迟","公羊",
|
||||
"澹台","公冶","宗政","濮阳",
|
||||
"淳于","单于","太叔","申屠",
|
||||
"公孙","仲孙","轩辕","令狐",
|
||||
"钟离","宇文","长孙","慕容",
|
||||
"鲜于","闾丘","司徒","司空",
|
||||
"亓官","司寇","仉督","子车",
|
||||
"颛孙","端木","巫马","公西",
|
||||
"漆雕","乐正","壤驷","公良",
|
||||
"拓跋","夹谷","宰父","榖梁",
|
||||
"晋","楚","闫","法","汝","鄢","涂","钦",
|
||||
"段干","百里","东郭","南门",
|
||||
"呼延","归","海","羊舌","微","生",
|
||||
"岳","帅","缑","亢","况","后","有","琴",
|
||||
"梁丘","左丘","东门","西门",
|
||||
"商","牟","佘","佴","伯","赏","南宫",
|
||||
"墨","哈","谯","笪","年","爱","阳","佟",
|
||||
"第五","言","福"])
|
||||
m = set(["赵", "钱", "孙", "李",
|
||||
"周", "吴", "郑", "王",
|
||||
"冯", "陈", "褚", "卫",
|
||||
"蒋", "沈", "韩", "杨",
|
||||
"朱", "秦", "尤", "许",
|
||||
"何", "吕", "施", "张",
|
||||
"孔", "曹", "严", "华",
|
||||
"金", "魏", "陶", "姜",
|
||||
"戚", "谢", "邹", "喻",
|
||||
"柏", "水", "窦", "章",
|
||||
"云", "苏", "潘", "葛",
|
||||
"奚", "范", "彭", "郎",
|
||||
"鲁", "韦", "昌", "马",
|
||||
"苗", "凤", "花", "方",
|
||||
"俞", "任", "袁", "柳",
|
||||
"酆", "鲍", "史", "唐",
|
||||
"费", "廉", "岑", "薛",
|
||||
"雷", "贺", "倪", "汤",
|
||||
"滕", "殷", "罗", "毕",
|
||||
"郝", "邬", "安", "常",
|
||||
"乐", "于", "时", "傅",
|
||||
"皮", "卞", "齐", "康",
|
||||
"伍", "余", "元", "卜",
|
||||
"顾", "孟", "平", "黄",
|
||||
"和", "穆", "萧", "尹",
|
||||
"姚", "邵", "湛", "汪",
|
||||
"祁", "毛", "禹", "狄",
|
||||
"米", "贝", "明", "臧",
|
||||
"计", "伏", "成", "戴",
|
||||
"谈", "宋", "茅", "庞",
|
||||
"熊", "纪", "舒", "屈",
|
||||
"项", "祝", "董", "梁",
|
||||
"杜", "阮", "蓝", "闵",
|
||||
"席", "季", "麻", "强",
|
||||
"贾", "路", "娄", "危",
|
||||
"江", "童", "颜", "郭",
|
||||
"梅", "盛", "林", "刁",
|
||||
"钟", "徐", "邱", "骆",
|
||||
"高", "夏", "蔡", "田",
|
||||
"樊", "胡", "凌", "霍",
|
||||
"虞", "万", "支", "柯",
|
||||
"昝", "管", "卢", "莫",
|
||||
"经", "房", "裘", "缪",
|
||||
"干", "解", "应", "宗",
|
||||
"丁", "宣", "贲", "邓",
|
||||
"郁", "单", "杭", "洪",
|
||||
"包", "诸", "左", "石",
|
||||
"崔", "吉", "钮", "龚",
|
||||
"程", "嵇", "邢", "滑",
|
||||
"裴", "陆", "荣", "翁",
|
||||
"荀", "羊", "於", "惠",
|
||||
"甄", "曲", "家", "封",
|
||||
"芮", "羿", "储", "靳",
|
||||
"汲", "邴", "糜", "松",
|
||||
"井", "段", "富", "巫",
|
||||
"乌", "焦", "巴", "弓",
|
||||
"牧", "隗", "山", "谷",
|
||||
"车", "侯", "宓", "蓬",
|
||||
"全", "郗", "班", "仰",
|
||||
"秋", "仲", "伊", "宫",
|
||||
"宁", "仇", "栾", "暴",
|
||||
"甘", "钭", "厉", "戎",
|
||||
"祖", "武", "符", "刘",
|
||||
"景", "詹", "束", "龙",
|
||||
"叶", "幸", "司", "韶",
|
||||
"郜", "黎", "蓟", "薄",
|
||||
"印", "宿", "白", "怀",
|
||||
"蒲", "邰", "从", "鄂",
|
||||
"索", "咸", "籍", "赖",
|
||||
"卓", "蔺", "屠", "蒙",
|
||||
"池", "乔", "阴", "鬱",
|
||||
"胥", "能", "苍", "双",
|
||||
"闻", "莘", "党", "翟",
|
||||
"谭", "贡", "劳", "逄",
|
||||
"姬", "申", "扶", "堵",
|
||||
"冉", "宰", "郦", "雍",
|
||||
"郤", "璩", "桑", "桂",
|
||||
"濮", "牛", "寿", "通",
|
||||
"边", "扈", "燕", "冀",
|
||||
"郏", "浦", "尚", "农",
|
||||
"温", "别", "庄", "晏",
|
||||
"柴", "瞿", "阎", "充",
|
||||
"慕", "连", "茹", "习",
|
||||
"宦", "艾", "鱼", "容",
|
||||
"向", "古", "易", "慎",
|
||||
"戈", "廖", "庾", "终",
|
||||
"暨", "居", "衡", "步",
|
||||
"都", "耿", "满", "弘",
|
||||
"匡", "国", "文", "寇",
|
||||
"广", "禄", "阙", "东",
|
||||
"欧", "殳", "沃", "利",
|
||||
"蔚", "越", "夔", "隆",
|
||||
"师", "巩", "厍", "聂",
|
||||
"晁", "勾", "敖", "融",
|
||||
"冷", "訾", "辛", "阚",
|
||||
"那", "简", "饶", "空",
|
||||
"曾", "母", "沙", "乜",
|
||||
"养", "鞠", "须", "丰",
|
||||
"巢", "关", "蒯", "相",
|
||||
"查", "后", "荆", "红",
|
||||
"游", "竺", "权", "逯",
|
||||
"盖", "益", "桓", "公",
|
||||
"兰", "原", "乞", "西", "阿", "肖", "丑", "位", "曽", "巨", "德", "代", "圆", "尉", "仵", "纳", "仝", "脱",
|
||||
"丘", "但", "展", "迪", "付", "覃", "晗", "特", "隋", "苑", "奥", "漆", "谌", "郄", "练", "扎", "邝", "渠",
|
||||
"信", "门", "陳", "化", "原", "密", "泮", "鹿", "赫",
|
||||
"万俟", "司马", "上官", "欧阳",
|
||||
"夏侯", "诸葛", "闻人", "东方",
|
||||
"赫连", "皇甫", "尉迟", "公羊",
|
||||
"澹台", "公冶", "宗政", "濮阳",
|
||||
"淳于", "单于", "太叔", "申屠",
|
||||
"公孙", "仲孙", "轩辕", "令狐",
|
||||
"钟离", "宇文", "长孙", "慕容",
|
||||
"鲜于", "闾丘", "司徒", "司空",
|
||||
"亓官", "司寇", "仉督", "子车",
|
||||
"颛孙", "端木", "巫马", "公西",
|
||||
"漆雕", "乐正", "壤驷", "公良",
|
||||
"拓跋", "夹谷", "宰父", "榖梁",
|
||||
"晋", "楚", "闫", "法", "汝", "鄢", "涂", "钦",
|
||||
"段干", "百里", "东郭", "南门",
|
||||
"呼延", "归", "海", "羊舌", "微", "生",
|
||||
"岳", "帅", "缑", "亢", "况", "后", "有", "琴",
|
||||
"梁丘", "左丘", "东门", "西门",
|
||||
"商", "牟", "佘", "佴", "伯", "赏", "南宫",
|
||||
"墨", "哈", "谯", "笪", "年", "爱", "阳", "佟",
|
||||
"第五", "言", "福"])

-def isit(n):return n.strip() in m

+def isit(n): return n.strip() in m

@@ -1,4 +1,4 @@
-#
+#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

@@ -108,13 +108,14 @@ class Dealer:
                if re.match(p, t):
                    tk = "#"
                    break
-           #tk = re.sub(r"([\+\\-])", r"\\\1", tk)
+           # tk = re.sub(r"([\+\\-])", r"\\\1", tk)
            if tk != "#" and tk:
                res.append(tk)
        return res

    def token_merge(self, tks):
-       def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
+       def one_term(t):
+           return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

        res, i = [], 0
        while i < len(tks):

@@ -152,8 +153,8 @@ class Dealer:
        tks = []
        for t in re.sub(r"[ \t]+", " ", txt).split():
            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
                    re.match(r".*[a-zA-Z]$", t) and tks and \
                    self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
                tks[-1] = tks[-1] + " " + t
            else:
                tks.append(t)

@@ -220,14 +221,15 @@ class Dealer:

        return 3

-       def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
+       def idf(s, N):
+           return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

        tw = []
        if not preprocess:
            idf1 = np.array([idf(freq(t), 10000000) for t in tks])
            idf2 = np.array([idf(df(t), 1000000000) for t in tks])
            wts = (0.3 * idf1 + 0.7 * idf2) * \
                np.array([ner(t) * postag(t) for t in tks])
            wts = [s for s in wts]
            tw = list(zip(tks, wts))
        else:

@@ -236,7 +238,7 @@ class Dealer:
            idf1 = np.array([idf(freq(t), 10000000) for t in tt])
            idf2 = np.array([idf(df(t), 1000000000) for t in tt])
            wts = (0.3 * idf1 + 0.7 * idf2) * \
                np.array([ner(t) * postag(t) for t in tt])
            wts = [s for s in wts]
            tw.extend(zip(tt, wts))


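Splitting the one-line idf helper above is purely cosmetic; the weighting itself blends two BM25-style IDF estimates, one against an assumed corpus of ten million documents and one against a billion, then multiplies by NER and part-of-speech factors. A minimal sketch with stand-in lookup functions rather than the Dealer's real frequency tables:

import math
import numpy as np

def idf(s, N):
    # BM25-style inverse document frequency with +10 smoothing, as in the hunk above.
    return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

def term_weights(tokens, freq, df, ner, postag):
    # freq, df, ner and postag are illustrative stand-ins for the Dealer's lookup tables.
    idf1 = np.array([idf(freq(t), 10000000) for t in tokens])
    idf2 = np.array([idf(df(t), 1000000000) for t in tokens])
    wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tokens])
    return list(zip(tokens, wts))

# e.g. term_weights(["ragflow", "search"], lambda t: 100, lambda t: 1000, lambda t: 1.0, lambda t: 1.0)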
@@ -3,4 +3,4 @@ from . import generator
__all__ = [name for name in dir(generator)
           if not name.startswith('_')]

globals().update({name: getattr(generator, name) for name in __all__})

@@ -28,17 +28,16 @@ from rag.prompts.template import load_prompt
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string


-STOP_TOKEN="<|STOP|>"
-COMPLETE_TASK="complete_task"
+STOP_TOKEN = "<|STOP|>"
+COMPLETE_TASK = "complete_task"
INPUT_UTILIZATION = 0.5


def get_value(d, k1, k2):
    return d.get(k1, d.get(k2))


def chunks_format(reference):
-
    return [
        {
            "id": get_value(chunk, "chunk_id", "id"),

@@ -126,7 +125,7 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500))
        cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
        cnt += draw_node("URL", ck['url']) if "url" in ck else ""
        for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
            cnt += draw_node(k, v)
        cnt += "\n└── Content:\n"

@@ -173,7 +172,7 @@ ASK_SUMMARY = load_prompt("ask_summary")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)


-def citation_prompt(user_defined_prompts: dict={}) -> str:
+def citation_prompt(user_defined_prompts: dict = {}) -> str:
    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
    return template.render()

@@ -258,9 +257,11 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]):
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
-   rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
+   rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query,
+                                                                                                    languages=languages)

-   ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
+   ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}],
+                                   {"temperature": 0.2})
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if ans.find("**ERROR**") >= 0:
        return query
@@ -332,7 +333,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
            "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
            "parameters": {
                "type": "object",
-               "properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
+               "properties": {
+                   "answer": {"type": "string", "description": "The final answer to the user's question"}},
                "required": ["answer"]
            }
        }

@@ -341,7 +343,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
        name = tool["function"]["name"]
        desc[name] = tool

-   return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
+   return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in
+                       enumerate(desc.items())])


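The rewrapped return above builds a numbered listing of tool schemas, one "## N. name" heading per tool followed by its JSON description. A rough standalone sketch, assuming the OpenAI-style {"function": {"name": ...}} layout used elsewhere in this file:

import json

def tool_schema_sketch(tools_description: list[dict]) -> str:
    # Mirrors the numbered "## N. name" listing built in tool_schema above.
    desc = {}
    for tool in tools_description:
        name = tool["function"]["name"]
        desc[name] = tool
    return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}"
                        for i, (fnm, des) in enumerate(desc.items())])

# e.g. tool_schema_sketch([{"function": {"name": "search", "parameters": {"type": "object"}}}])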
def form_history(history, limit=-6):

@@ -350,14 +353,14 @@ def form_history(history, limit=-6):
        if h["role"] == "system":
            continue
        role = "USER"
-       if h["role"].upper()!= role:
+       if h["role"].upper() != role:
            role = "AGENT"
-       context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
+       context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
    return context



-async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict],
+                             user_defined_prompts: dict = {}):
    tools_desc = tool_schema(tools_description)
    context = ""

@@ -375,7 +378,8 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
    return kwd


-async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+async def next_step_async(chat_mdl, history: list, tools_description: list[dict], task_desc,
+                          user_defined_prompts: dict = {}):
    if not tools_description:
        return "", 0
    desc = tool_schema(tools_description)

@@ -396,7 +400,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
    return json_str, tk_cnt


-async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict = {}):
    tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
    goal = history[1]["content"]
    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))

@@ -419,7 +423,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple


def form_message(system_prompt, user_prompt):
-   return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
+   return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]


|
||||
@ -427,27 +431,29 @@ def structured_output_prompt(schema=None) -> str:
|
||||
return template.render(schema=schema)
|
||||
|
||||
|
||||
async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
|
||||
async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict = {}) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
|
||||
system_prompt = template.render(name=name,
|
||||
params=json.dumps(params, ensure_ascii=False, indent=2),
|
||||
result=result)
|
||||
params=json.dumps(params, ensure_ascii=False, indent=2),
|
||||
result=result)
|
||||
user_prompt = "→ Summary: "
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
|
||||
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
|
||||
async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str],
|
||||
user_defined_prompts: dict = {}):
|
||||
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
|
||||
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
|
||||
system_prompt = template.render(goal=goal, sub_goal=sub_goal,
|
||||
results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
|
||||
user_prompt = " → rank: "
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
|
||||
async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
|
||||
async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict:
|
||||
meta_data_structure = {}
|
||||
for key, values in meta_data.items():
|
||||
meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
|
||||
@ -471,13 +477,13 @@ async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
|
||||
return {"conditions": []}
|
||||
|
||||
|
||||
async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
|
||||
async def gen_json(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None):
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache
|
||||
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
|
||||
if cached:
|
||||
return json_repair.loads(cached)
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
|
||||
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], gen_conf=gen_conf)
|
||||
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
|
||||
try:
|
||||
res = json_repair.loads(ans)
|
||||
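The gen_json hunk shows the common pattern in this module: render the prompts, call the chat model, strip any think-tag preamble and markdown fences, then repair-parse the JSON. A stripped-down sketch of that flow; the LLM cache and message_fit_in budgeting are omitted, and chat_mdl.async_chat is assumed to take the system prompt plus a message list as in the code above.

import re

import json_repair

async def gen_json_sketch(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None):
    # Illustrative only: chat_mdl stands in for an LLMBundle-like object.
    ans = await chat_mdl.async_chat(system_prompt, [{"role": "user", "content": user_prompt}], gen_conf=gen_conf)
    # Drop any <think> preamble and markdown code fences before parsing.
    ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
    return json_repair.loads(ans)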
@ -488,10 +494,13 @@ async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None
|
||||
|
||||
|
||||
TOC_DETECTION = load_prompt("toc_detection")
|
||||
async def detect_table_of_contents(page_1024:list[str], chat_mdl):
|
||||
|
||||
|
||||
async def detect_table_of_contents(page_1024: list[str], chat_mdl):
|
||||
toc_secs = []
|
||||
for i, sec in enumerate(page_1024[:22]):
|
||||
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
|
||||
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.",
|
||||
chat_mdl)
|
||||
if toc_secs and not ans["exists"]:
|
||||
break
|
||||
toc_secs.append(sec)
|
||||
@ -500,14 +509,17 @@ async def detect_table_of_contents(page_1024:list[str], chat_mdl):
|
||||
|
||||
TOC_EXTRACTION = load_prompt("toc_extraction")
|
||||
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
|
||||
|
||||
|
||||
async def extract_table_of_contents(toc_pages, chat_mdl):
|
||||
if not toc_pages:
|
||||
return []
|
||||
|
||||
return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
|
||||
return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)),
|
||||
"Only JSON please.", chat_mdl)
|
||||
|
||||
|
||||
async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
|
||||
async def toc_index_extractor(toc: list[dict], content: str, chat_mdl):
|
||||
tob_extractor_prompt = """
|
||||
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
|
||||
|
||||
@ -529,18 +541,21 @@ async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
|
||||
If the title of the section are not in the provided pages, do not add the physical_index to it.
|
||||
Directly return the final JSON structure. Do not output anything else."""
|
||||
|
||||
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
|
||||
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False,
|
||||
indent=2) + '\nDocument pages:\n' + content
|
||||
return await gen_json(prompt, "Only JSON please.", chat_mdl)
|
||||
|
||||
|
||||
TOC_INDEX = load_prompt("toc_index")
|
||||
|
||||
|
||||
async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
|
||||
if not toc_arr or not sections:
|
||||
return []
|
||||
|
||||
toc_map = {}
|
||||
for i, it in enumerate(toc_arr):
|
||||
k1 = (it["structure"]+it["title"]).replace(" ", "")
|
||||
k1 = (it["structure"] + it["title"]).replace(" ", "")
|
||||
k2 = it["title"].strip()
|
||||
if k1 not in toc_map:
|
||||
toc_map[k1] = []
|
||||
@ -558,6 +573,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
|
||||
toc_arr[j]["indices"].append(i)
|
||||
|
||||
all_pathes = []
|
||||
|
||||
def dfs(start, path):
|
||||
nonlocal all_pathes
|
||||
if start >= len(toc_arr):
|
||||
@ -565,7 +581,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
|
||||
all_pathes.append(path)
|
||||
return
|
||||
if not toc_arr[start]["indices"]:
|
||||
dfs(start+1, path)
|
||||
dfs(start + 1, path)
|
||||
return
|
||||
added = False
|
||||
for j in toc_arr[start]["indices"]:
|
||||
@ -574,12 +590,12 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
|
||||
_path = deepcopy(path)
|
||||
_path.append((j, start))
|
||||
added = True
|
||||
dfs(start+1, _path)
|
||||
dfs(start + 1, _path)
|
||||
if not added and path:
|
||||
all_pathes.append(path)
|
||||
|
||||
dfs(0, [])
|
||||
path = max(all_pathes, key=lambda x:len(x))
|
||||
path = max(all_pathes, key=lambda x: len(x))
|
||||
for it in toc_arr:
|
||||
it["indices"] = []
|
||||
for j, i in path:
|
||||
@ -588,24 +604,24 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
|
||||
|
||||
i = 0
|
||||
while i < len(toc_arr):
|
||||
it = toc_arr[i]
|
||||
it = toc_arr[i]
|
||||
if it["indices"]:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if i>0 and toc_arr[i-1]["indices"]:
|
||||
st_i = toc_arr[i-1]["indices"][-1]
|
||||
if i > 0 and toc_arr[i - 1]["indices"]:
|
||||
st_i = toc_arr[i - 1]["indices"][-1]
|
||||
else:
|
||||
st_i = 0
|
||||
e = i + 1
|
||||
while e <len(toc_arr) and not toc_arr[e]["indices"]:
|
||||
while e < len(toc_arr) and not toc_arr[e]["indices"]:
|
||||
e += 1
|
||||
if e >= len(toc_arr):
|
||||
e = len(sections)
|
||||
else:
|
||||
e = toc_arr[e]["indices"][0]
|
||||
|
||||
for j in range(st_i, min(e+1, len(sections))):
|
||||
for j in range(st_i, min(e + 1, len(sections))):
|
||||
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
|
||||
structure=it["structure"],
|
||||
title=it["title"],
|
||||
@ -656,11 +672,15 @@ async def toc_transformer(toc_pages, chat_mdl):
|
||||
|
||||
toc_content = "\n".join(toc_pages)
|
||||
prompt = init_prompt + '\n Given table of contents\n:' + toc_content
|
||||
|
||||
def clean_toc(arr):
|
||||
for a in arr:
|
||||
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
|
||||
|
||||
last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
|
||||
if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
|
||||
if_complete = await check_if_toc_transformation_is_complete(toc_content,
|
||||
json.dumps(last_complete, ensure_ascii=False, indent=2),
|
||||
chat_mdl)
|
||||
clean_toc(last_complete)
|
||||
if if_complete == "yes":
|
||||
return last_complete
|
||||
@ -682,13 +702,17 @@ async def toc_transformer(toc_pages, chat_mdl):
|
||||
break
|
||||
clean_toc(new_complete)
|
||||
last_complete.extend(new_complete)
|
||||
if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
|
||||
if_complete = await check_if_toc_transformation_is_complete(toc_content,
|
||||
json.dumps(last_complete, ensure_ascii=False,
|
||||
indent=2), chat_mdl)
|
||||
|
||||
return last_complete
|
||||
|
||||
|
||||
TOC_LEVELS = load_prompt("assign_toc_levels")
|
||||
async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
|
||||
|
||||
|
||||
async def assign_toc_levels(toc_secs, chat_mdl, gen_conf={"temperature": 0.2}):
|
||||
if not toc_secs:
|
||||
return []
|
||||
return await gen_json(
|
||||
@ -701,12 +725,15 @@ async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2})
|
||||
|
||||
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
|
||||
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
|
||||
|
||||
|
||||
# Generate TOC from text chunks with text llms
|
||||
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
|
||||
try:
|
||||
ans = await gen_json(
|
||||
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
|
||||
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
|
||||
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(
|
||||
text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
|
||||
chat_mdl,
|
||||
gen_conf={"temperature": 0.0, "top_p": 0.9}
|
||||
)
|
||||
@ -743,7 +770,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
|
||||
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
|
||||
)
|
||||
|
||||
input_budget = 1024 if input_budget > 1024 else input_budget
|
||||
input_budget = 1024 if input_budget > 1024 else input_budget
|
||||
chunk_sections = split_chunks(chunks, input_budget)
|
||||
titles = []
|
||||
|
||||
@ -798,7 +825,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
|
||||
if sorted_list:
|
||||
max_lvl = sorted_list[-1]
|
||||
merged = []
|
||||
for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
|
||||
for _, (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
|
||||
if prune and toc_item.get("level", "0") >= max_lvl:
|
||||
continue
|
||||
merged.append({
|
||||
@ -812,12 +839,15 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
|
||||
|
||||
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
|
||||
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
|
||||
async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
|
||||
|
||||
|
||||
async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
|
||||
import numpy as np
|
||||
try:
|
||||
ans = await gen_json(
|
||||
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
|
||||
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
|
||||
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join(
|
||||
[json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
|
||||
chat_mdl,
|
||||
gen_conf={"temperature": 0.0, "top_p": 0.9}
|
||||
)
|
||||
@ -828,17 +858,19 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
|
||||
for id in ti.get("ids", []):
|
||||
if id not in id2score:
|
||||
id2score[id] = []
|
||||
id2score[id].append(sc["score"]/5.)
|
||||
id2score[id].append(sc["score"] / 5.)
|
||||
for id in id2score.keys():
|
||||
id2score[id] = np.mean(id2score[id])
|
||||
return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn]
|
||||
return [(id, sc) for id, sc in list(id2score.items()) if sc >= 0.3][:topn]
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return []
|
||||
|
||||
|
||||
META_DATA = load_prompt("meta_data")
|
||||
async def gen_metadata(chat_mdl, schema:dict, content:str):
|
||||
|
||||
|
||||
async def gen_metadata(chat_mdl, schema: dict, content: str):
|
||||
template = PROMPT_JINJA_ENV.from_string(META_DATA)
|
||||
for k, desc in schema["properties"].items():
|
||||
if "enum" in desc and not desc.get("enum"):
|
||||
@ -849,4 +881,4 @@ async def gen_metadata(chat_mdl, schema:dict, content:str):
|
||||
user_prompt = "Output: "
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import os
|
||||
|
||||
|
||||
PROMPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
_loaded_prompts = {}
|
||||
|
||||
@@ -48,13 +48,15 @@ def main():
                REDIS_CONN.transaction(key, file_bin, 12 * 60)
                logging.info("CACHE: {}".format(loc))
            except Exception as e:
-               traceback.print_stack(e)
+               logging.error(f"Error to get data from REDIS: {e}")
+               traceback.print_stack()
    except Exception as e:
-       traceback.print_stack(e)
+       logging.error(f"Error to check REDIS connection: {e}")
+       traceback.print_stack()


if __name__ == "__main__":
    while True:
        main()
        close_connection()
        time.sleep(1)

@ -19,16 +19,15 @@ import requests
|
||||
import base64
|
||||
import asyncio
|
||||
|
||||
URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
|
||||
URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
|
||||
|
||||
JSON_DATA = {
|
||||
"conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
|
||||
"Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
|
||||
"word": "" # User question, don't need to initialize
|
||||
"conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
|
||||
"Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
|
||||
"word": "" # User question, don't need to initialize
|
||||
}
|
||||
|
||||
DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" #Get DISCORD_BOT_KEY from Discord Application
|
||||
|
||||
DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" # Get DISCORD_BOT_KEY from Discord Application
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
@ -50,7 +49,7 @@ async def on_message(message):
|
||||
if len(message.content.split('> ')) == 1:
|
||||
await message.channel.send("Hi~ How can I help you? ")
|
||||
else:
|
||||
JSON_DATA['word']=message.content.split('> ')[1]
|
||||
JSON_DATA['word'] = message.content.split('> ')[1]
|
||||
response = requests.post(URL, json=JSON_DATA)
|
||||
response_data = response.json().get('data', [])
|
||||
image_bool = False
|
||||
@ -61,9 +60,9 @@ async def on_message(message):
|
||||
if i['type'] == 3:
|
||||
image_bool = True
|
||||
image_data = base64.b64decode(i['url'])
|
||||
with open('tmp_image.png','wb') as file:
|
||||
with open('tmp_image.png', 'wb') as file:
|
||||
file.write(image_data)
|
||||
image= discord.File('tmp_image.png')
|
||||
image = discord.File('tmp_image.png')
|
||||
|
||||
await message.channel.send(f"{message.author.mention}{res}")
|
||||
|
||||
|
||||
@ -38,7 +38,8 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from common import settings
|
||||
from common.config_utils import show_configs
|
||||
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
|
||||
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, \
|
||||
MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
|
||||
from common.constants import FileSource, TaskStatus
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.confluence_connector import ConfluenceConnector
|
||||
@ -96,7 +97,7 @@ class SyncBase:
|
||||
if task["poll_range_start"]:
|
||||
next_update = task["poll_range_start"]
|
||||
|
||||
for document_batch in document_batch_generator:
|
||||
for document_batch in document_batch_generator:
|
||||
if not document_batch:
|
||||
continue
|
||||
|
||||
@ -161,6 +162,7 @@ class SyncBase:
|
||||
def _get_source_prefix(self):
|
||||
return ""
|
||||
|
||||
|
||||
class _BlobLikeBase(SyncBase):
|
||||
DEFAULT_BUCKET_TYPE: str = "s3"
|
||||
|
||||
@ -199,22 +201,27 @@ class _BlobLikeBase(SyncBase):
|
||||
)
|
||||
return document_batch_generator
|
||||
|
||||
|
||||
class S3(_BlobLikeBase):
|
||||
SOURCE_NAME: str = FileSource.S3
|
||||
DEFAULT_BUCKET_TYPE: str = "s3"
|
||||
|
||||
|
||||
class R2(_BlobLikeBase):
|
||||
SOURCE_NAME: str = FileSource.R2
|
||||
DEFAULT_BUCKET_TYPE: str = "r2"
|
||||
|
||||
|
||||
class OCI_STORAGE(_BlobLikeBase):
|
||||
SOURCE_NAME: str = FileSource.OCI_STORAGE
|
||||
DEFAULT_BUCKET_TYPE: str = "oci_storage"
|
||||
|
||||
|
||||
class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
|
||||
SOURCE_NAME: str = FileSource.GOOGLE_CLOUD_STORAGE
|
||||
DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"
|
||||
|
||||
|
||||
class Confluence(SyncBase):
|
||||
SOURCE_NAME: str = FileSource.CONFLUENCE
|
||||
|
||||
@ -248,7 +255,9 @@ class Confluence(SyncBase):
|
||||
index_recursively=index_recursively,
|
||||
)
|
||||
|
||||
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
|
||||
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"],
|
||||
connector_name=DocumentSource.CONFLUENCE,
|
||||
credential_json=self.conf["credentials"])
|
||||
self.connector.set_credentials_provider(credentials_provider)
|
||||
|
||||
# Determine the time range for synchronization based on reindex or poll_range_start
|
||||
@ -280,7 +289,8 @@ class Confluence(SyncBase):
|
||||
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
|
||||
for document, failure, next_checkpoint in doc_generator:
|
||||
if failure is not None:
|
||||
logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure))
|
||||
logging.warning("Confluence connector failure: %s",
|
||||
getattr(failure, "failure_message", failure))
|
||||
continue
|
||||
if document is not None:
|
||||
pending_docs.append(document)
|
||||
@ -300,7 +310,7 @@ class Confluence(SyncBase):
|
||||
async def async_wrapper():
|
||||
for batch in document_batches():
|
||||
yield batch
|
||||
|
||||
|
||||
logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
|
||||
return async_wrapper()
|
||||
|
||||
@ -314,10 +324,12 @@ class Notion(SyncBase):
|
||||
document_generator = (
|
||||
self.connector.load_from_state()
|
||||
if task["reindex"] == "1" or not task["poll_range_start"]
|
||||
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
|
||||
else self.connector.poll_source(task["poll_range_start"].timestamp(),
|
||||
datetime.now(timezone.utc).timestamp())
|
||||
)
|
||||
|
||||
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
|
||||
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
|
||||
task["poll_range_start"])
|
||||
logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
|
||||
return document_generator
|
||||
|
||||
@ -340,10 +352,12 @@ class Discord(SyncBase):
|
||||
document_generator = (
|
||||
self.connector.load_from_state()
|
||||
if task["reindex"] == "1" or not task["poll_range_start"]
|
||||
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
|
||||
else self.connector.poll_source(task["poll_range_start"].timestamp(),
|
||||
datetime.now(timezone.utc).timestamp())
|
||||
)
|
||||
|
||||
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
|
||||
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
|
||||
task["poll_range_start"])
|
||||
logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
|
||||
return document_generator
|
||||
|
||||
@ -485,7 +499,8 @@ class GoogleDrive(SyncBase):
|
||||
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
|
||||
for document, failure, next_checkpoint in doc_generator:
|
||||
if failure is not None:
|
||||
logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
|
||||
logging.warning("Google Drive connector failure: %s",
|
||||
getattr(failure, "failure_message", failure))
|
||||
continue
|
||||
if document is not None:
|
||||
pending_docs.append(document)
|
||||
@ -646,10 +661,10 @@ class WebDAV(SyncBase):
|
||||
remote_path=self.conf.get("remote_path", "/")
|
||||
)
|
||||
self.connector.load_credentials(self.conf["credentials"])
|
||||
|
||||
|
||||
logging.info(f"Task info: reindex={task['reindex']}, poll_range_start={task['poll_range_start']}")
|
||||
|
||||
if task["reindex"]=="1" or not task["poll_range_start"]:
|
||||
|
||||
if task["reindex"] == "1" or not task["poll_range_start"]:
|
||||
logging.info("Using load_from_state (full sync)")
|
||||
document_batch_generator = self.connector.load_from_state()
|
||||
begin_info = "totally"
|
||||
@ -659,14 +674,15 @@ class WebDAV(SyncBase):
|
||||
logging.info(f"Polling WebDAV from {task['poll_range_start']} (ts: {start_ts}) to now (ts: {end_ts})")
|
||||
document_batch_generator = self.connector.poll_source(start_ts, end_ts)
|
||||
begin_info = "from {}".format(task["poll_range_start"])
|
||||
|
||||
|
||||
logging.info("Connect to WebDAV: {}(path: {}) {}".format(
|
||||
self.conf["base_url"],
|
||||
self.conf.get("remote_path", "/"),
|
||||
begin_info
|
||||
))
|
||||
return document_batch_generator
|
||||
|
||||
|
||||
|
||||
class Moodle(SyncBase):
|
||||
SOURCE_NAME: str = FileSource.MOODLE
|
||||
|
||||
@ -675,7 +691,7 @@ class Moodle(SyncBase):
|
||||
moodle_url=self.conf["moodle_url"],
|
||||
batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE)
|
||||
)
|
||||
|
||||
|
||||
self.connector.load_credentials(self.conf["credentials"])
|
||||
|
||||
# Determine the time range for synchronization based on reindex or poll_range_start
|
||||
@ -689,7 +705,7 @@ class Moodle(SyncBase):
|
||||
begin_info = "totally"
|
||||
else:
|
||||
document_generator = self.connector.poll_source(
|
||||
poll_start.timestamp(),
|
||||
poll_start.timestamp(),
|
||||
datetime.now(timezone.utc).timestamp()
|
||||
)
|
||||
begin_info = "from {}".format(poll_start)
|
||||
@ -718,7 +734,7 @@ class BOX(SyncBase):
|
||||
token = AccessToken(
|
||||
access_token=credential['access_token'],
|
||||
refresh_token=credential['refresh_token'],
|
||||
)
|
||||
)
|
||||
auth.token_storage.store(token)
|
||||
|
||||
self.connector.load_credentials(auth)
|
||||
@ -739,6 +755,7 @@ class BOX(SyncBase):
|
||||
logging.info("Connect to Box: folder_id({}) {}".format(self.conf["folder_id"], begin_info))
|
||||
return document_generator
|
||||
|
||||
|
||||
class Airtable(SyncBase):
|
||||
SOURCE_NAME: str = FileSource.AIRTABLE
|
||||
|
||||
@ -784,6 +801,7 @@ class Airtable(SyncBase):
|
||||
|
||||
return document_generator
|
||||
|
||||
|
||||
func_factory = {
|
||||
FileSource.S3: S3,
|
||||
FileSource.R2: R2,
|
||||
|
||||
@ -92,7 +92,7 @@ FACTORY = {
|
||||
}
|
||||
|
||||
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
|
||||
"dataflow" : PipelineTaskType.PARSE,
|
||||
"dataflow": PipelineTaskType.PARSE,
|
||||
"raptor": PipelineTaskType.RAPTOR,
|
||||
"graphrag": PipelineTaskType.GRAPH_RAG,
|
||||
"mindmap": PipelineTaskType.MINDMAP,
|
||||
@ -221,7 +221,7 @@ async def get_storage_binary(bucket, name):
|
||||
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
|
||||
|
||||
|
||||
@timeout(60*80, 1)
|
||||
@timeout(60 * 80, 1)
|
||||
async def build_chunks(task, progress_callback):
|
||||
if task["size"] > settings.DOC_MAXIMUM_SIZE:
|
||||
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
|
||||
@ -283,7 +283,8 @@ async def build_chunks(task, progress_callback):
|
||||
try:
|
||||
d = copy.deepcopy(document)
|
||||
d.update(chunk)
|
||||
d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
||||
d["id"] = xxhash.xxh64(
|
||||
(chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
||||
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
if not d.get("image"):
|
||||
@ -328,9 +329,11 @@ async def build_chunks(task, progress_callback):
|
||||
d["important_kwd"] = cached.split(",")
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
||||
return
|
||||
|
||||
tasks = []
|
||||
for d in docs:
|
||||
tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
|
||||
tasks.append(
|
||||
asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
@ -355,9 +358,11 @@ async def build_chunks(task, progress_callback):
|
||||
if cached:
|
||||
d["question_kwd"] = cached.split("\n")
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||
|
||||
tasks = []
|
||||
for d in docs:
|
||||
tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
|
||||
tasks.append(
|
||||
asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
@ -374,15 +379,18 @@ async def build_chunks(task, progress_callback):
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
|
||||
async def gen_metadata_task(chat_mdl, d):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", task["parser_config"]["metadata"])
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
|
||||
task["parser_config"]["metadata"])
|
||||
if not cached:
|
||||
async with chat_limiter:
|
||||
cached = await gen_metadata(chat_mdl,
|
||||
metadata_schema(task["parser_config"]["metadata"]),
|
||||
d["content_with_weight"])
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", task["parser_config"]["metadata"])
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
|
||||
task["parser_config"]["metadata"])
|
||||
if cached:
|
||||
d["metadata_obj"] = cached
|
||||
|
||||
tasks = []
|
||||
for d in docs:
|
||||
tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
|
||||
@ -430,7 +438,8 @@ async def build_chunks(task, progress_callback):
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return None
|
||||
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
|
||||
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
|
||||
d[TAG_FLD]) > 0:
|
||||
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
|
||||
else:
|
||||
docs_to_tag.append(d)
|
||||
@ -438,7 +447,7 @@ async def build_chunks(task, progress_callback):
|
||||
async def doc_content_tagging(chat_mdl, d, topn_tags):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
|
||||
if not cached:
|
||||
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
|
||||
picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
|
||||
if not picked_examples:
|
||||
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
|
||||
async with chat_limiter:
|
||||
@ -454,6 +463,7 @@ async def build_chunks(task, progress_callback):
|
||||
if cached:
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
|
||||
d[TAG_FLD] = json.loads(cached)
|
||||
|
||||
tasks = []
|
||||
for d in docs_to_tag:
|
||||
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
|
||||
@ -473,21 +483,22 @@ async def build_chunks(task, progress_callback):
|
||||
def build_TOC(task, docs, progress_callback):
|
||||
progress_callback(msg="Start to generate table of content ...")
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
docs = sorted(docs, key=lambda d:(
|
||||
docs = sorted(docs, key=lambda d: (
|
||||
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
||||
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
|
||||
))
|
||||
toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
|
||||
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
|
||||
toc: list[dict] = asyncio.run(
|
||||
run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
|
||||
logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
|
||||
ii = 0
|
||||
while ii < len(toc):
|
||||
try:
|
||||
idx = int(toc[ii]["chunk_id"])
|
||||
del toc[ii]["chunk_id"]
|
||||
toc[ii]["ids"] = [docs[idx]["id"]]
|
||||
if ii == len(toc) -1:
|
||||
if ii == len(toc) - 1:
|
||||
break
|
||||
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
|
||||
for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
|
||||
toc[ii]["ids"].append(docs[jj]["id"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
@ -499,7 +510,8 @@ def build_TOC(task, docs, progress_callback):
|
||||
d["toc_kwd"] = "toc"
|
||||
d["available_int"] = 0
|
||||
d["page_num_int"] = [100000000]
|
||||
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
||||
d["id"] = xxhash.xxh64(
|
||||
(d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
||||
return d
|
||||
return None
|
||||
|
||||
@ -532,12 +544,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
@timeout(60)
|
||||
def batch_encode(txts):
|
||||
nonlocal mdl
|
||||
return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
|
||||
return mdl.encode([truncate(c, mdl.max_length - 10) for c in txts])
|
||||
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
|
||||
vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
@ -545,7 +557,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
tk_count += c
|
||||
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
|
||||
cnts = cnts_
|
||||
filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
|
||||
filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
|
||||
if not filename_embd_weight:
|
||||
filename_embd_weight = 0.1
|
||||
title_w = float(filename_embd_weight)
|
||||
@ -588,7 +600,8 @@ async def run_dataflow(task: dict):
|
||||
return
|
||||
|
||||
if not chunks:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
|
||||
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||
return
|
||||
|
||||
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
|
||||
@ -610,25 +623,27 @@ async def run_dataflow(task: dict):
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
embedding_id = kb.embd_id
|
||||
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
|
||||
@timeout(60)
|
||||
def batch_encode(txts):
|
||||
nonlocal embedding_model
|
||||
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
|
||||
|
||||
vects = np.array([])
|
||||
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
|
||||
delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1)
|
||||
delta = 0.20 / (len(texts) // settings.EMBEDDING_BATCH_SIZE + 1)
|
||||
prog = 0.8
|
||||
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
|
||||
vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
|
||||
if len(vects) == 0:
|
||||
vects = vts
|
||||
else:
|
||||
vects = np.concatenate((vects, vts), axis=0)
|
||||
embedding_token_consumption += c
|
||||
prog += delta
|
||||
if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1:
|
||||
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}")
|
||||
if i % (len(texts) // settings.EMBEDDING_BATCH_SIZE / 100 + 1) == 1:
|
||||
set_progress(task_id, prog=prog, msg=f"{i + 1} / {len(texts) // settings.EMBEDDING_BATCH_SIZE}")
|
||||
|
||||
assert len(vects) == len(chunks)
|
||||
for i, ck in enumerate(chunks):
|
||||
@ -636,10 +651,10 @@ async def run_dataflow(task: dict):
|
||||
ck["q_%d_vec" % len(v)] = v
|
||||
except Exception as e:
|
||||
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
|
||||
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||
return
|
||||
|
||||
|
||||
metadata = {}
|
||||
for ck in chunks:
|
||||
ck["doc_id"] = doc_id
|
||||
@ -686,15 +701,19 @@ async def run_dataflow(task: dict):
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return

time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
dsl=str(pipeline))


@timeout(3600)
@ -702,7 +721,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID

raptor_config = kb_parser_config.get("raptor", {})
vctr_nm = "q_%d_vec"%vector_size
vctr_nm = "q_%d_vec" % vector_size

res = []
tk_count = 0
@ -747,17 +766,17 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for x, doc_id in enumerate(doc_ids):
chunks = []
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, doc_id)
callback(prog=(x+1.)/len(doc_ids))
callback(prog=(x + 1.) / len(doc_ids))
else:
chunks = []
for doc_id in doc_ids:
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))

await generate(chunks, fake_doc_id)
@ -792,19 +811,22 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mom_ck["available_int"] = 0
flds = list(mom_ck.keys())
for fld in flds:
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int"]:
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
"position_int"]:
del mom_ck[fld]
mothers.append(mom_ck)

for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False

for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
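Aside, not part of the diff: the bulk-insert loop above writes DOC_BULK_SIZE chunks at a time and re-checks for cancellation between batches. A minimal sketch of that shape, with a dummy store call and cancel flag standing in for docStoreConn and has_canceled:

# Illustrative sketch only; insert_batch and is_canceled are assumed stand-ins.
import asyncio

DOC_BULK_SIZE = 4

def insert_batch(batch):
    # assumed blocking document-store call
    print(f"inserted {len(batch)} chunks")

async def bulk_insert(chunks, is_canceled):
    for b in range(0, len(chunks), DOC_BULK_SIZE):
        await asyncio.to_thread(insert_batch, chunks[b:b + DOC_BULK_SIZE])
        if is_canceled():  # stop between batches, never mid-batch
            return False
    return True

# asyncio.run(bulk_insert(list(range(10)), lambda: False))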
@ -821,7 +843,8 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
search.index_name(task_tenant_id), task_dataset_id, )
tasks = []
for chunk_id in chunk_ids:
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
@ -838,7 +861,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
return True


@timeout(60*60*3, 1)
@timeout(60 * 60 * 3, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")

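Side note, not from the commit: @timeout(60 * 60 * 3, 1) is RAGFlow's own decorator, defined elsewhere in the repo, and its second argument may not mean what is assumed here. Purely as a hedged illustration, an asyncio timeout decorator of this general shape could look like:

# Hypothetical decorator; the real @timeout in the repo may behave differently.
import asyncio
import functools

def timeout(seconds, default=None):
    def deco(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            try:
                return await asyncio.wait_for(fn(*args, **kwargs), timeout=seconds)
            except asyncio.TimeoutError:
                return default  # fall back instead of propagating the timeout
        return wrapper
    return deco

@timeout(2, default="gave up")
async def slow():
    await asyncio.sleep(5)

# asyncio.run(slow()) returns "gave up"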
@ -914,7 +937,7 @@ async def do_handle_task(task):
},
}
)
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return

@ -943,7 +966,7 @@ async def do_handle_task(task):
doc_ids=task.get("doc_ids", []),
)
if fake_doc_ids := task.get("doc_ids", []):
task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
task_doc_id = fake_doc_ids[0]  # use the first document ID to represent this task for logging purposes
# Either using graphrag or Standard chunking methods
elif task_type == "graphrag":
ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
@ -968,11 +991,10 @@ async def do_handle_task(task):
}
}
)
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return


graphrag_conf = kb_parser_config.get("graphrag", {})
start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
@ -1030,7 +1052,7 @@ async def do_handle_task(task):
return True
e = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
return bool(e)


try:
if not await _maybe_insert_es(chunks):
return
@ -1084,8 +1106,8 @@ async def do_handle_task(task):
f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled."
)

async def handle_task():

async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
@ -1093,7 +1115,8 @@ async def handle_task():
return

task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
PipelineTaskType.PARSE) or PipelineTaskType.PARSE

try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
@ -1119,7 +1142,9 @@ async def handle_task():
if task_type in ["graphrag", "raptor", "mindmap"]:
task_document_ids = task["doc_ids"]
if not task.get("dataflow_id", ""):
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
task_type=pipeline_task_type,
fake_document_ids=task_document_ids)

redis_msg.ack()

@ -1249,6 +1274,7 @@ async def main():
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!")


if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)

@ -42,8 +42,10 @@ class RAGFlowAzureSpnBlob:
pass

try:
credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id, client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name, credential=credentials)
credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id,
client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name,
credential=credentials)
except Exception:
logging.exception("Fail to connect %s" % self.account_url)

@ -104,4 +106,4 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return None
return None

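For readers unfamiliar with the Azure SDK calls touched above (not part of the diff): the connection is a ClientSecretCredential handed to a Data Lake FileSystemClient. A minimal sketch with placeholder identifiers:

# Sketch only; tenant/client/secret values are placeholders.
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
from azure.storage.filedatalake import FileSystemClient

credential = ClientSecretCredential(
    tenant_id="<tenant-id>",
    client_id="<client-id>",
    client_secret="<client-secret>",
    authority=AzureAuthorityHosts.AZURE_CHINA,  # same authority the code above passes
)
fs = FileSystemClient(
    account_url="https://<account>.dfs.core.windows.net",
    file_system_name="<container>",
    credential=credential,
)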
@ -25,7 +25,8 @@ from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64)

async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):

async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str = "imagetemps"):
import logging
from io import BytesIO
from rag.svr.task_executor import minio_limiter
@ -74,7 +75,7 @@ async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="
del d["image"]


def id2image(image_id:str|None, storage_get_func: partial):
def id2image(image_id: str | None, storage_get_func: partial):
if not image_id:
return
arr = image_id.split("-")

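A quick illustration of the naming convention these helpers appear to rely on (my reading of the code, not part of the diff): an image is stored under bucket/objname and later referenced by an id that id2image splits on "-". A hedged sketch of the id round-trip:

# Sketch; the real helpers are async and go through RAGFlow's storage layer.
from io import BytesIO
from PIL import Image

def image_to_bytes(img: Image.Image) -> bytes:
    buf = BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()

def make_image_id(bucket: str, objname: str) -> str:
    return f"{bucket}-{objname}"  # the string id2image later splits on "-"

def parse_image_id(image_id: str) -> tuple[str, str]:
    bucket, objname = image_id.split("-", 1)
    return bucket, objname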
@ -16,11 +16,13 @@

import logging
from common.crypto_utils import CryptoUtil


# from common.decorator import singleton

class EncryptedStorageWrapper:
"""Encrypted storage wrapper that wraps existing storage implementations to provide transparent encryption"""

def __init__(self, storage_impl, algorithm="aes-256-cbc", key=None, iv=None):
"""
Initialize encrypted storage wrapper
@ -34,16 +36,16 @@ class EncryptedStorageWrapper:
self.storage_impl = storage_impl
self.crypto = CryptoUtil(algorithm=algorithm, key=key, iv=iv)
self.encryption_enabled = True

# Check if storage implementation has required methods
# todo: Consider abstracting a storage base class to ensure these methods exist
required_methods = ["put", "get", "rm", "obj_exist", "health"]
for method in required_methods:
if not hasattr(storage_impl, method):
raise AttributeError(f"Storage implementation missing required method: {method}")

logging.info(f"EncryptedStorageWrapper initialized with algorithm: {algorithm}")

def put(self, bucket, fnm, binary, tenant_id=None):
"""
Encrypt and store data
@ -59,15 +61,15 @@ class EncryptedStorageWrapper:
"""
if not self.encryption_enabled:
return self.storage_impl.put(bucket, fnm, binary, tenant_id)

try:
encrypted_binary = self.crypto.encrypt(binary)

return self.storage_impl.put(bucket, fnm, encrypted_binary, tenant_id)
except Exception as e:
logging.exception(f"Failed to encrypt and store data: {bucket}/{fnm}, error: {str(e)}")
raise

def get(self, bucket, fnm, tenant_id=None):
"""
Retrieve and decrypt data
@ -83,21 +85,21 @@ class EncryptedStorageWrapper:
try:
# Get encrypted data
encrypted_binary = self.storage_impl.get(bucket, fnm, tenant_id)

if encrypted_binary is None:
return None

if not self.encryption_enabled:
return encrypted_binary

# Decrypt data
decrypted_binary = self.crypto.decrypt(encrypted_binary)
return decrypted_binary

except Exception as e:
logging.exception(f"Failed to get and decrypt data: {bucket}/{fnm}, error: {str(e)}")
raise

def rm(self, bucket, fnm, tenant_id=None):
"""
Delete data (same as original storage implementation, no decryption needed)
@ -111,7 +113,7 @@ class EncryptedStorageWrapper:
Deletion result
"""
return self.storage_impl.rm(bucket, fnm, tenant_id)

def obj_exist(self, bucket, fnm, tenant_id=None):
"""
Check if object exists (same as original storage implementation, no decryption needed)
@ -125,7 +127,7 @@ class EncryptedStorageWrapper:
Whether the object exists
"""
return self.storage_impl.obj_exist(bucket, fnm, tenant_id)

def health(self):
"""
Health check (uses the original storage implementation's method)
@ -134,7 +136,7 @@ class EncryptedStorageWrapper:
Health check result
"""
return self.storage_impl.health()

def bucket_exists(self, bucket):
"""
Check if bucket exists (if the original storage implementation has this method)
@ -148,7 +150,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "bucket_exists"):
return self.storage_impl.bucket_exists(bucket)
return False

def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
"""
Get presigned URL (if the original storage implementation has this method)
@ -165,7 +167,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "get_presigned_url"):
return self.storage_impl.get_presigned_url(bucket, fnm, expires, tenant_id)
return None

def scan(self, bucket, fnm, tenant_id=None):
"""
Scan objects (if the original storage implementation has this method)
@ -181,7 +183,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "scan"):
return self.storage_impl.scan(bucket, fnm, tenant_id)
return None

def copy(self, src_bucket, src_path, dest_bucket, dest_path):
"""
Copy object (if the original storage implementation has this method)
@ -198,7 +200,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "copy"):
return self.storage_impl.copy(src_bucket, src_path, dest_bucket, dest_path)
return False

def move(self, src_bucket, src_path, dest_bucket, dest_path):
"""
Move object (if the original storage implementation has this method)
@ -215,7 +217,7 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "move"):
return self.storage_impl.move(src_bucket, src_path, dest_bucket, dest_path)
return False

def remove_bucket(self, bucket):
"""
Remove bucket (if the original storage implementation has this method)
@ -229,17 +231,18 @@ class EncryptedStorageWrapper:
if hasattr(self.storage_impl, "remove_bucket"):
return self.storage_impl.remove_bucket(bucket)
return False

def enable_encryption(self):
"""Enable encryption"""
self.encryption_enabled = True
logging.info("Encryption enabled")

def disable_encryption(self):
"""Disable encryption"""
self.encryption_enabled = False
logging.info("Encryption disabled")

# Create singleton wrapper function
def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_enabled=True):
"""
@ -255,12 +258,12 @@ def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_
Encrypted storage wrapper instance
"""
wrapper = EncryptedStorageWrapper(storage_impl, algorithm=algorithm, key=key)

wrapper.encryption_enabled = encryption_enabled

if encryption_enabled:
logging.info("Encryption enabled in storage wrapper")
else:
logging.info("Encryption disabled in storage wrapper")

return wrapper

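Usage note (editorial, not in the commit): the factory above wraps any object exposing put/get/rm/obj_exist/health. A hedged sketch with a throwaway in-memory store as the wrapped implementation; MemoryStore is invented for illustration, while CryptoUtil and key handling stay inside the repo's own code:

# Sketch only; MemoryStore is a stand-in implementation.
class MemoryStore:
    def __init__(self):
        self.data = {}
    def put(self, bucket, fnm, binary, tenant_id=None):
        self.data[(bucket, fnm)] = binary
    def get(self, bucket, fnm, tenant_id=None):
        return self.data.get((bucket, fnm))
    def rm(self, bucket, fnm, tenant_id=None):
        self.data.pop((bucket, fnm), None)
    def obj_exist(self, bucket, fnm, tenant_id=None):
        return (bucket, fnm) in self.data
    def health(self):
        return True

# storage = create_encrypted_storage(MemoryStore(), algorithm="aes-256-cbc", key=b"0" * 32)
# storage.put("bucket", "doc.bin", b"plaintext")  # stored encrypted, read back decrypted via storage.get()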
||||
@ -32,7 +32,6 @@ ATTEMPT_TIME = 2
|
||||
|
||||
@singleton
|
||||
class ESConnection(ESConnectionBase):
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
@ -82,8 +81,9 @@ class ESConnection(ESConnectionBase):
|
||||
vector_similarity_weight = 0.5
|
||||
for m in match_expressions:
|
||||
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
|
||||
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
|
||||
MatchDenseExpr) and isinstance(
|
||||
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
|
||||
match_expressions[1],
|
||||
MatchDenseExpr) and isinstance(
|
||||
match_expressions[2], FusionExpr)
|
||||
weights = m.fusion_params["weights"]
|
||||
vector_similarity_weight = get_float(weights.split(",")[1])
|
||||
@ -93,9 +93,9 @@ class ESConnection(ESConnectionBase):
|
||||
if isinstance(minimum_should_match, float):
|
||||
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
|
||||
bool_query.must.append(Q("query_string", fields=m.fields,
|
||||
type="best_fields", query=m.matching_text,
|
||||
minimum_should_match=minimum_should_match,
|
||||
boost=1))
|
||||
type="best_fields", query=m.matching_text,
|
||||
minimum_should_match=minimum_should_match,
|
||||
boost=1))
|
||||
bool_query.boost = 1.0 - vector_similarity_weight
|
||||
|
||||
elif isinstance(m, MatchDenseExpr):
|
||||
@ -146,7 +146,7 @@ class ESConnection(ESConnectionBase):
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
#print(json.dumps(q, ensure_ascii=False))
|
||||
# print(json.dumps(q, ensure_ascii=False))
|
||||
res = self.es.search(index=index_names,
|
||||
body=q,
|
||||
timeout="600s",
|
||||
@ -220,13 +220,15 @@ class ESConnection(ESConnectionBase):
|
||||
try:
|
||||
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
|
||||
except Exception:
|
||||
self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
||||
self.logger.exception(
|
||||
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
||||
try:
|
||||
self.es.update(index=index_name, id=chunk_id, doc=doc)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.exception(
|
||||
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
|
||||
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(
|
||||
e))
|
||||
break
|
||||
return False
|
||||
|
||||
|
||||
@ -25,18 +25,23 @@ import PyPDF2
|
||||
from docx import Document
|
||||
import olefile
|
||||
|
||||
|
||||
def _is_zip(h: bytes) -> bool:
|
||||
return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08")
|
||||
|
||||
|
||||
def _is_pdf(h: bytes) -> bool:
|
||||
return h.startswith(b"%PDF-")
|
||||
|
||||
|
||||
def _is_ole(h: bytes) -> bool:
|
||||
return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1")
|
||||
|
||||
|
||||
def _sha10(b: bytes) -> str:
|
||||
return hashlib.sha256(b).hexdigest()[:10]
|
||||
|
||||
|
||||
def _guess_ext(b: bytes) -> str:
|
||||
h = b[:8]
|
||||
if _is_zip(h):
|
||||
@ -58,13 +63,14 @@ def _guess_ext(b: bytes) -> str:
|
||||
return ".doc"
|
||||
return ".bin"
|
||||
|
||||
|
||||
# Try to extract the real embedded payload from OLE's Ole10Native
|
||||
def _extract_ole10native_payload(data: bytes) -> bytes:
|
||||
try:
|
||||
pos = 0
|
||||
if len(data) < 4:
|
||||
return data
|
||||
_ = int.from_bytes(data[pos:pos+4], "little")
|
||||
_ = int.from_bytes(data[pos:pos + 4], "little")
|
||||
pos += 4
|
||||
# filename/src/tmp (NUL-terminated ANSI)
|
||||
for _ in range(3):
|
||||
@ -74,14 +80,15 @@ def _extract_ole10native_payload(data: bytes) -> bytes:
|
||||
pos += 4
|
||||
if pos + 4 > len(data):
|
||||
return data
|
||||
size = int.from_bytes(data[pos:pos+4], "little")
|
||||
size = int.from_bytes(data[pos:pos + 4], "little")
|
||||
pos += 4
|
||||
if pos + size <= len(data):
|
||||
return data[pos:pos+size]
|
||||
return data[pos:pos + size]
|
||||
except Exception:
|
||||
pass
|
||||
return data
|
||||
|
||||
|
||||
def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]:
|
||||
"""
|
||||
Only extract the 'first layer' of embedding, returning raw (filename, bytes).
|
||||
@ -163,7 +170,7 @@ def extract_links_from_docx(docx_bytes: bytes):
|
||||
# Each relationship may represent a hyperlink, image, footer, etc.
|
||||
for rel in document.part.rels.values():
|
||||
if rel.reltype == (
|
||||
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
|
||||
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
|
||||
):
|
||||
links.add(rel.target_ref)
|
||||
|
||||
@ -198,6 +205,8 @@ def extract_links_from_pdf(pdf_bytes: bytes):
|
||||
|
||||
|
||||
_GLOBAL_SESSION: Optional[requests.Session] = None
|
||||
|
||||
|
||||
def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
|
||||
"""Get or create a global reusable session."""
|
||||
global _GLOBAL_SESSION
|
||||
@ -216,10 +225,10 @@ def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
|
||||
|
||||
|
||||
def extract_html(
|
||||
url: str,
|
||||
timeout: float = 60.0,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
max_retries: int = 2,
|
||||
url: str,
|
||||
timeout: float = 60.0,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
max_retries: int = 2,
|
||||
) -> Tuple[Optional[bytes], Dict[str, str]]:
|
||||
"""
|
||||
Extract the full HTML page as raw bytes from a given URL.
|
||||
@ -260,4 +269,4 @@ def extract_html(
|
||||
metadata["error"] = f"Request failed: {e}"
|
||||
continue
|
||||
|
||||
return None, metadata
|
||||
return None, metadata
|
||||
|
||||
@ -204,4 +204,4 @@ class RAGFlowGCS:
|
||||
return False
|
||||
except Exception:
|
||||
logging.exception(f"Fail to move {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
|
||||
return False
|
||||
return False
|
||||
|
||||
@ -28,7 +28,6 @@ from common.doc_store.infinity_conn_base import InfinityConnectionBase
|
||||
|
||||
@singleton
|
||||
class InfinityConnection(InfinityConnectionBase):
|
||||
|
||||
"""
|
||||
Dataframe and fields convert
|
||||
"""
|
||||
@ -83,24 +82,23 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
tokens[0] = field
|
||||
return "^".join(tokens)
|
||||
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
|
||||
self,
|
||||
select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
index_names: str | list[str],
|
||||
knowledgebase_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None,
|
||||
self,
|
||||
select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
index_names: str | list[str],
|
||||
knowledgebase_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None,
|
||||
) -> tuple[pd.DataFrame, int]:
|
||||
"""
|
||||
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
|
||||
@ -159,7 +157,8 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
if table_found:
|
||||
break
|
||||
if not table_found:
|
||||
self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
|
||||
self.logger.error(
|
||||
f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
|
||||
return pd.DataFrame(), 0
|
||||
|
||||
for matchExpr in match_expressions:
|
||||
@ -280,7 +279,8 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
|
||||
self.logger.warning(
|
||||
f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
|
||||
continue
|
||||
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
|
||||
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
|
||||
@ -288,7 +288,9 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = self.concat_dataframes(df_list, ["id"])
|
||||
fields = set(res.columns.tolist())
|
||||
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
|
||||
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd",
|
||||
"question_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks",
|
||||
"authors_sm_tks"]:
|
||||
fields.add(field)
|
||||
res_fields = self.get_fields(res, list(fields))
|
||||
return res_fields.get(chunk_id, None)
|
||||
@ -379,7 +381,9 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
d[k] = "_".join(f"{num:08x}" for num in v)
|
||||
else:
|
||||
d[k] = v
|
||||
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
|
||||
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
|
||||
"content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd",
|
||||
"question_tks"]:
|
||||
if k in d:
|
||||
del d[k]
|
||||
|
||||
@ -478,7 +482,8 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
del new_value[k]
|
||||
else:
|
||||
new_value[k] = v
|
||||
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
|
||||
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
|
||||
"content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
|
||||
if k in new_value:
|
||||
del new_value[k]
|
||||
|
||||
@ -502,7 +507,8 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
|
||||
for update_kv, ids in remove_opt.items():
|
||||
k, v = json.loads(update_kv)
|
||||
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
|
||||
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])),
|
||||
{k: "###".join(v)})
|
||||
|
||||
table_instance.update(filter, new_value)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
@ -561,7 +567,7 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
def to_position_int(v):
|
||||
if v:
|
||||
arr = [int(hex_val, 16) for hex_val in v.split("_")]
|
||||
v = [arr[i : i + 5] for i in range(0, len(arr), 5)]
|
||||
v = [arr[i: i + 5] for i in range(0, len(arr), 5)]
|
||||
else:
|
||||
v = []
|
||||
return v
|
||||
|
||||
@ -46,6 +46,7 @@ class RAGFlowMinio:
|
||||
# pass original identifier forward for use by other decorators
|
||||
kwargs['_orig_bucket'] = original_bucket
|
||||
return method(self, actual_bucket, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
@ -71,6 +72,7 @@ class RAGFlowMinio:
|
||||
fnm = f"{orig_bucket}/{fnm}"
|
||||
|
||||
return method(self, bucket, fnm, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def __open__(self):
|
||||
|
||||
@ -37,7 +37,8 @@ from common import settings
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common.decorator import singleton
|
||||
from common.float_utils import get_float
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
|
||||
MatchDenseExpr
|
||||
from rag.nlp import rag_tokenizer
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
@ -719,19 +720,19 @@ class OBConnection(DocStoreConnection):
|
||||
"""
|
||||
|
||||
def search(
|
||||
self,
|
||||
selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str | list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
rank_feature: dict | None = None,
|
||||
**kwargs,
|
||||
self,
|
||||
selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str | list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
rank_feature: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(indexNames, str):
|
||||
indexNames = indexNames.split(",")
|
||||
@ -1546,7 +1547,7 @@ class OBConnection(DocStoreConnection):
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
if len(re.findall(r'</em><em>', highlighted_txt)) > 0 or len(
|
||||
re.findall(r'</em>\s*<em>', highlighted_txt)) > 0:
|
||||
re.findall(r'</em>\s*<em>', highlighted_txt)) > 0:
|
||||
return highlighted_txt
|
||||
else:
|
||||
return None
|
||||
@ -1565,9 +1566,9 @@ class OBConnection(DocStoreConnection):
|
||||
if token_pos != -1:
|
||||
if token in keywords:
|
||||
highlighted_txt = (
|
||||
highlighted_txt[:token_pos] +
|
||||
f'<em>{token}</em>' +
|
||||
highlighted_txt[token_pos + len(token):]
|
||||
highlighted_txt[:token_pos] +
|
||||
f'<em>{token}</em>' +
|
||||
highlighted_txt[token_pos + len(token):]
|
||||
)
|
||||
last_pos = token_pos
|
||||
return re.sub(r'</em><em>', '', highlighted_txt)
|
||||
|
||||
@ -6,7 +6,6 @@ from urllib.parse import quote_plus
|
||||
from common.config_utils import get_base_config
|
||||
from common.decorator import singleton
|
||||
|
||||
|
||||
CREATE_TABLE_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS `{}` (
|
||||
`key` VARCHAR(255) PRIMARY KEY,
|
||||
@ -36,7 +35,8 @@ def get_opendal_config():
|
||||
"table": opendal_config.get("config", {}).get("oss_table", "opendal_storage"),
|
||||
"max_allowed_packet": str(max_packet)
|
||||
}
|
||||
kwargs["connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
|
||||
kwargs[
|
||||
"connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
|
||||
else:
|
||||
scheme = opendal_config.get("scheme")
|
||||
config_data = opendal_config.get("config", {})
|
||||
@ -61,7 +61,7 @@ def get_opendal_config():
|
||||
del kwargs["password"]
|
||||
if "connection_string" in kwargs:
|
||||
del kwargs["connection_string"]
|
||||
return kwargs
|
||||
return kwargs
|
||||
except Exception as e:
|
||||
logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e))
|
||||
raise
|
||||
@ -99,7 +99,6 @@ class OpenDALStorage:
|
||||
def obj_exist(self, bucket, fnm, tenant_id=None):
|
||||
return self._operator.exists(f"{bucket}/{fnm}")
|
||||
|
||||
|
||||
def init_db_config(self):
|
||||
try:
|
||||
conn = pymysql.connect(
|
||||
|
||||
@ -26,7 +26,8 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
|
||||
from opensearchpy import ConnectionTimeout
|
||||
from common.decorator import singleton
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common import settings
|
||||
@ -189,7 +190,7 @@ class OSConnection(DocStoreConnection):
|
||||
minimum_should_match=minimum_should_match,
|
||||
boost=1))
|
||||
bqry.boost = 1.0 - vector_similarity_weight
|
||||
|
||||
|
||||
# Elasticsearch has the encapsulation of KNN_search in python sdk
|
||||
# while the Python SDK for OpenSearch does not provide encapsulation for KNN_search,
|
||||
# the following codes implement KNN_search in OpenSearch using DSL
|
||||
@ -216,7 +217,7 @@ class OSConnection(DocStoreConnection):
|
||||
if bqry:
|
||||
s = s.query(bqry)
|
||||
for field in highlightFields:
|
||||
s = s.highlight(field,force_source=True,no_match_size=30,require_field_match=False)
|
||||
s = s.highlight(field, force_source=True, no_match_size=30, require_field_match=False)
|
||||
|
||||
if orderBy:
|
||||
orders = list()
|
||||
@ -239,10 +240,10 @@ class OSConnection(DocStoreConnection):
|
||||
s = s[offset:offset + limit]
|
||||
q = s.to_dict()
|
||||
logger.debug(f"OSConnection.search {str(indexNames)} query: " + json.dumps(q))
|
||||
|
||||
|
||||
if use_knn:
|
||||
del q["query"]
|
||||
q["query"] = {"knn" : knn_query}
|
||||
q["query"] = {"knn": knn_query}
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
@ -328,7 +329,7 @@ class OSConnection(DocStoreConnection):
|
||||
chunkId = condition["id"]
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
self.os.update(index=indexName, id=chunkId, body={"doc":doc})
|
||||
self.os.update(index=indexName, id=chunkId, body={"doc": doc})
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
@ -435,7 +436,7 @@ class OSConnection(DocStoreConnection):
|
||||
logger.debug("OSConnection.delete query: " + json.dumps(qry.to_dict()))
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
#print(Search().query(qry).to_dict(), flush=True)
|
||||
# print(Search().query(qry).to_dict(), flush=True)
|
||||
res = self.os.delete_by_query(
|
||||
index=indexName,
|
||||
body=Search().query(qry).to_dict(),
|
||||
|
||||
@ -42,14 +42,16 @@ class RAGFlowOSS:
|
||||
# If there is a default bucket, use the default bucket
|
||||
actual_bucket = self.bucket if self.bucket else bucket
|
||||
return method(self, actual_bucket, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@staticmethod
|
||||
def use_prefix_path(method):
|
||||
def wrapper(self, bucket, fnm, *args, **kwargs):
|
||||
# If the prefix path is set, use the prefix path
|
||||
fnm = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm
|
||||
return method(self, bucket, fnm, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def __open__(self):
|
||||
@ -171,4 +173,3 @@ class RAGFlowOSS:
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return None
|
||||
|
||||
|
||||
@ -21,7 +21,6 @@ Utility functions for Raptor processing decisions.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# File extensions for structured data types
|
||||
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
|
||||
CSV_EXTENSIONS = {".csv", ".tsv"}
|
||||
@ -40,12 +39,12 @@ def is_structured_file_type(file_type: Optional[str]) -> bool:
|
||||
"""
|
||||
if not file_type:
|
||||
return False
|
||||
|
||||
|
||||
# Normalize to lowercase and ensure leading dot
|
||||
file_type = file_type.lower()
|
||||
if not file_type.startswith("."):
|
||||
file_type = f".{file_type}"
|
||||
|
||||
|
||||
return file_type in STRUCTURED_EXTENSIONS
|
||||
|
||||
|
||||
@ -61,23 +60,23 @@ def is_tabular_pdf(parser_id: str = "", parser_config: Optional[dict] = None) ->
|
||||
True if PDF is being parsed as tabular data
|
||||
"""
|
||||
parser_config = parser_config or {}
|
||||
|
||||
|
||||
# If using table parser, it's tabular
|
||||
if parser_id and parser_id.lower() == "table":
|
||||
return True
|
||||
|
||||
|
||||
# Check if html4excel is enabled (Excel-like table parsing)
|
||||
if parser_config.get("html4excel", False):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def should_skip_raptor(
|
||||
file_type: Optional[str] = None,
|
||||
parser_id: str = "",
|
||||
parser_config: Optional[dict] = None,
|
||||
raptor_config: Optional[dict] = None
|
||||
file_type: Optional[str] = None,
|
||||
parser_id: str = "",
|
||||
parser_config: Optional[dict] = None,
|
||||
raptor_config: Optional[dict] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if Raptor should be skipped for a given document.
|
||||
@ -97,30 +96,30 @@ def should_skip_raptor(
|
||||
"""
|
||||
parser_config = parser_config or {}
|
||||
raptor_config = raptor_config or {}
|
||||
|
||||
|
||||
# Check if auto-disable is explicitly disabled in config
|
||||
if raptor_config.get("auto_disable_for_structured_data", True) is False:
|
||||
logging.info("Raptor auto-disable is turned off via configuration")
|
||||
return False
|
||||
|
||||
|
||||
# Check for Excel/CSV files
|
||||
if is_structured_file_type(file_type):
|
||||
logging.info(f"Skipping Raptor for structured file type: {file_type}")
|
||||
return True
|
||||
|
||||
|
||||
# Check for tabular PDFs
|
||||
if file_type and file_type.lower() in [".pdf", "pdf"]:
|
||||
if is_tabular_pdf(parser_id, parser_config):
|
||||
logging.info(f"Skipping Raptor for tabular PDF (parser_id={parser_id})")
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_skip_reason(
|
||||
file_type: Optional[str] = None,
|
||||
parser_id: str = "",
|
||||
parser_config: Optional[dict] = None
|
||||
file_type: Optional[str] = None,
|
||||
parser_id: str = "",
|
||||
parser_config: Optional[dict] = None
|
||||
) -> str:
|
||||
"""
|
||||
Get a human-readable reason why Raptor was skipped.
|
||||
@ -134,12 +133,12 @@ def get_skip_reason(
|
||||
Reason string, or empty string if Raptor should not be skipped
|
||||
"""
|
||||
parser_config = parser_config or {}
|
||||
|
||||
|
||||
if is_structured_file_type(file_type):
|
||||
return f"Structured data file ({file_type}) - Raptor auto-disabled"
|
||||
|
||||
|
||||
if file_type and file_type.lower() in [".pdf", "pdf"]:
|
||||
if is_tabular_pdf(parser_id, parser_config):
|
||||
return f"Tabular PDF (parser={parser_id}) - Raptor auto-disabled"
|
||||
|
||||
|
||||
return ""
|
||||
|
||||
@ -33,6 +33,7 @@ except Exception:
|
||||
except Exception:
|
||||
REDIS = {}
|
||||
|
||||
|
||||
class RedisMsg:
|
||||
def __init__(self, consumer, queue_name, group_name, msg_id, message):
|
||||
self.__consumer = consumer
|
||||
@ -278,7 +279,8 @@ class RedisDB:
|
||||
def decrby(self, key: str, decrement: int):
|
||||
return self.REDIS.decrby(key, decrement)
|
||||
|
||||
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int:
|
||||
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default",
|
||||
increment: int = 1, ensure_minimum: int | None = None) -> int:
|
||||
redis_key = f"{key_prefix}:{namespace}"
|
||||
|
||||
try:
|
||||
|
||||
@ -46,6 +46,7 @@ class RAGFlowS3:
|
||||
# If there is a default bucket, use the default bucket
|
||||
actual_bucket = self.bucket if self.bucket else bucket
|
||||
return method(self, actual_bucket, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
@ -57,6 +58,7 @@ class RAGFlowS3:
|
||||
if self.prefix_path:
|
||||
fnm = f"{self.prefix_path}/{bucket}/{fnm}"
|
||||
return method(self, bucket, fnm, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def __open__(self):
|
||||
@ -81,16 +83,16 @@ class RAGFlowS3:
|
||||
s3_params['region_name'] = self.region_name
|
||||
if self.endpoint_url:
|
||||
s3_params['endpoint_url'] = self.endpoint_url
|
||||
|
||||
|
||||
# Configure signature_version and addressing_style through Config object
|
||||
if self.signature_version:
|
||||
config_kwargs['signature_version'] = self.signature_version
|
||||
if self.addressing_style:
|
||||
config_kwargs['s3'] = {'addressing_style': self.addressing_style}
|
||||
|
||||
|
||||
if config_kwargs:
|
||||
s3_params['config'] = Config(**config_kwargs)
|
||||
|
||||
|
||||
self.conn = [boto3.client('s3', **s3_params)]
|
||||
except Exception:
|
||||
logging.exception(f"Fail to connect at region {self.region_name} or endpoint {self.endpoint_url}")
|
||||
@ -184,9 +186,9 @@ class RAGFlowS3:
|
||||
for _ in range(10):
|
||||
try:
|
||||
r = self.conn[0].generate_presigned_url('get_object',
|
||||
Params={'Bucket': bucket,
|
||||
'Key': fnm},
|
||||
ExpiresIn=expires)
|
||||
Params={'Bucket': bucket,
|
||||
'Key': fnm},
|
||||
ExpiresIn=expires)
|
||||
|
||||
return r
|
||||
except Exception:
|
||||
|
||||
@ -30,7 +30,8 @@ class Tavily:
|
||||
search_depth="advanced",
|
||||
max_results=6
|
||||
)
|
||||
return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res in response["results"]]
|
||||
return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res
|
||||
in response["results"]]
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -64,5 +65,5 @@ class Tavily:
|
||||
"count": 1,
|
||||
"url": r["url"]
|
||||
})
|
||||
logging.info("[Tavily]R: "+r["content"][:128]+"...")
|
||||
return {"chunks": chunks, "doc_aggs": aggs}
|
||||
logging.info("[Tavily]R: " + r["content"][:128] + "...")
|
||||
return {"chunks": chunks, "doc_aggs": aggs}
|
||||
|
||||