Fix IDE warnings (#12281)

### What problem does this PR solve?

Clean up warnings flagged by the IDE across the RAG pipeline: wrap over-long statements and imports, normalize PEP 8 blank-line spacing, fix a malformed `logging.debug(...)` format call, replace invalid `traceback.print_stack(e)` calls with `logging.error` plus `traceback.print_stack()`, and translate a Chinese inline comment to English.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Author: Jin Hai
Date: 2025-12-29 12:01:18 +08:00
Committed by: GitHub
Parent: 647fb115a0
Commit: 01f0ced1e6
43 changed files with 817 additions and 637 deletions

View File

@@ -34,7 +34,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
    if not ext:
        raise RuntimeError("No extension detected.")
-    if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
+    if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma",
+                   ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
        raise RuntimeError(f"Extension {ext} is not supported yet.")
    tmp_path = ""

View File

@@ -93,7 +93,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
                                 random_choices([t for t, _ in sections], k=200)))
        tbls = vision_figure_parser_docx_wrapper(sections=sections, tbls=tbls, callback=callback, **kwargs)
        # tbls = [((None, lns), None) for lns in tbls]
-        sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
+        sections = [(item[0], item[1] if item[1] is not None else "") for item in sections if
+                    not isinstance(item[1], Image.Image)]
        callback(0.8, "Finish parsing.")
    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
@@ -199,6 +200,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
    import sys
+
    def dummy(prog=None, msg=""):
        pass
+
    chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)

View File

@@ -93,7 +93,8 @@ def chunk(
        _add_content(msg, msg.get_content_type())
    sections = TxtParser.parser_txt("\n".join(text_txt)) + [
-        (line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
+        (line, "") for line in
+        HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
    ]
    st = timer()
@@ -126,7 +127,9 @@ def chunk(
if __name__ == "__main__":
    import sys
+
    def dummy(prog=None, msg=""):
        pass
+
    chunk(sys.argv[1], callback=dummy)

View File

@@ -29,8 +29,6 @@ from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer


-
-
class Docx(DocxParser):
    def __init__(self):
        pass
@@ -89,7 +87,6 @@ class Docx(DocxParser):
        return [element for element in root.get_tree() if element]

-
    def __str__(self) -> str:
        return f'''
            question:{self.question},
@@ -121,8 +118,7 @@ class Pdf(PdfParser):
        start = timer()
        self._layouts_rec(zoomin)
        callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
-        logging.debug("layouts:".format(
-        ))
+        logging.debug("layouts: {}".format((timer() - start)))
        self._naive_vertical_merge()
        callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
@@ -226,7 +222,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
        raise NotImplementedError(
            "file type not supported yet(doc, docx, pdf, txt supported)")

-
    # Remove 'Contents' part
    remove_contents_table(sections, eng)
@@ -234,7 +229,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
    bull = bullets_category(sections)
    res = tree_merge(bull, sections, 2)

-
    if not res:
        callback(0.99, "No chunk parsed out.")
@@ -243,9 +237,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
    # chunks = hierarchical_merge(bull, sections, 5)
    # return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
+
+
if __name__ == "__main__":
    import sys
+
    def dummy(prog=None, msg=""):
        pass
+
    chunk(sys.argv[1], callback=dummy)

View File

@@ -20,7 +20,8 @@ import re
from common.constants import ParserType
from io import BytesIO
-from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level, attach_media_context
+from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, \
+    docx_question_level, attach_media_context
from common.token_utils import num_tokens_from_string
from deepdoc.parser import PdfParser, DocxParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper, vision_figure_parser_docx_wrapper
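Most import hunks in this commit wrap over-long `from ... import ...` lines with a trailing backslash. A parenthesized import list is an equivalent way to stay under the line limit without continuation characters; a generic sketch using standard-library names (the RAGFlow imports themselves mean the same thing either way):

```python
# Backslash continuation, the style used throughout this diff:
# from some.module import name_a, name_b, name_c, \
#     name_d, name_e

# Equivalent parenthesized form:
from os.path import (
    basename,
    dirname,
    join,
)

print(join(dirname("/tmp/a/b.txt"), basename("/tmp/a/b.txt")))  # /tmp/a/b.txt
```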
@@ -29,6 +30,7 @@ from PIL import Image
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer

+
class Pdf(PdfParser):
    def __init__(self):
        self.model_speciess = ParserType.MANUAL.value

View File

@@ -31,16 +31,20 @@ from common.token_utils import num_tokens_from_string
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html
-from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
-from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
+from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, \
+    PdfParser, TxtParser
+from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_docx_wrapper, \
+    vision_figure_parser_pdf_wrapper
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from deepdoc.parser.docling_parser import DoclingParser
from deepdoc.parser.tcadp_parser import TCADPParser
from common.parser_config_utils import normalize_layout_recognizer
-from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
+from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, \
+    tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context


-def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
+def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
+               **kwargs):
    callback = callback
    binary = binary
    pdf_parser = pdf_cls() if pdf_cls else Pdf()
@@ -106,7 +110,8 @@ def by_mineru(
    return None, None, None


-def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
+def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
+               **kwargs):
    pdf_parser = DoclingParser()
    parse_method = kwargs.get("parse_method", "raw")
@@ -426,7 +431,8 @@ class Docx(DocxParser):
        try:
            if inline_images:
-                result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64))
+                result = mammoth.convert_to_html(docx_file,
+                                                 convert_image=mammoth.images.img_element(_convert_image_to_base64))
            else:
                result = mammoth.convert_to_html(docx_file)
@@ -621,6 +627,7 @@ class Markdown(MarkdownParser):
            return sections, tbls, section_images
        return sections, tbls

+
def load_from_xml_v2(baseURI, rels_item_xml):
    """
    Return |_SerializedRelationships| instance loaded with the
@@ -636,6 +643,9 @@ def load_from_xml_v2(baseURI, rels_item_xml):
        srels._srels.append(_SerializedRelationship(baseURI, rel_elm))
    return srels

+
+
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
    """
    Supported file formats are docx, pdf, excel, txt.
@@ -651,7 +659,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True})
-    child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
+    child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode(
+        'latin1').decode('utf-8')
    cust_child_deli = re.findall(r"`([^`]+)`", child_deli)
    child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))
    if cust_child_deli:
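The `children_delimiter` line wrapped above relies on an encode/decode round-trip so that escape sequences typed literally in the parser config (for example the two characters `\` and `n`) become real control characters while multi-byte UTF-8 delimiters survive intact. A small self-contained illustration of that chain, with a made-up sample value:

```python
# A config value as a user might type it: a literal backslash-n plus a CJK delimiter.
raw = r"\n" + "；"                  # the characters '\', 'n', then U+FF1B

decoded = (
    raw.encode("utf-8")             # '；' becomes a 3-byte UTF-8 sequence
       .decode("unicode_escape")    # '\' + 'n' collapses to a real newline; other bytes map to latin-1 chars
       .encode("latin1")            # map those chars back to the original bytes
       .decode("utf-8")             # reassemble the multi-byte CJK character
)

assert decoded == "\n；"
print(repr(decoded))
```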
@ -685,7 +694,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
# Recursively chunk each embedded file and collect results # Recursively chunk each embedded file and collect results
for embed_filename, embed_bytes in embeds: for embed_filename, embed_bytes in embeds:
try: try:
sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False, **kwargs) or [] sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False,
**kwargs) or []
embed_res.extend(sub_res) embed_res.extend(sub_res)
except Exception as e: except Exception as e:
if callback: if callback:
@ -704,7 +714,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs) sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
except Exception as e: except Exception as e:
logging.info(f"Failed to chunk url in registered file type {url}: {e}") logging.info(f"Failed to chunk url in registered file type {url}: {e}")
sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False, **kwargs) sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False,
**kwargs)
url_res.extend(sub_url_res) url_res.extend(sub_url_res)
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246 # fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
@ -846,9 +857,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
else: else:
section_images = [None] * len(sections) section_images = [None] * len(sections)
section_images[idx] = combined_image section_images[idx] = combined_image
markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs) markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=[
((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
boosted_figures = markdown_vision_parser(callback=callback) boosted_figures = markdown_vision_parser(callback=callback)
sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]), sections[idx][1]) sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]),
sections[idx][1])
else: else:
logging.warning("No visual model detected. Skipping figure parsing enhancement.") logging.warning("No visual model detected. Skipping figure parsing enhancement.")
@ -945,7 +958,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
has_images = merged_images and any(img is not None for img in merged_images) has_images = merged_images and any(img is not None for img in merged_images)
if has_images: if has_images:
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images, child_delimiters_pattern=child_deli)) res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images,
child_delimiters_pattern=child_deli))
else: else:
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser, child_delimiters_pattern=child_deli)) res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser, child_delimiters_pattern=child_deli))
else: else:
@ -958,7 +972,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
int(parser_config.get( int(parser_config.get(
"chunk_token_num", 128)), parser_config.get( "chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?")) "delimiter", "\n!?。;!?"))
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli)) res.extend(
tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli))
else: else:
chunks = naive_merge( chunks = naive_merge(
sections, int(parser_config.get( sections, int(parser_config.get(
@ -993,7 +1008,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@@ -26,6 +26,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer

+
class Pdf(PdfParser):
    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
@@ -172,7 +173,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
    import sys
+
    def dummy(prog=None, msg=""):
        pass
+
    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@@ -20,7 +20,8 @@ import re
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
from common.constants import ParserType
-from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks, attach_media_context
+from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, \
+    tokenize_chunks, attach_media_context
from deepdoc.parser import PdfParser
import numpy as np
from rag.app.naive import by_plaintext, PARSERS
@@ -329,6 +330,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
    import sys
+
    def dummy(prog=None, msg=""):
        pass
+
    chunk(sys.argv[1], callback=dummy)

View File

@@ -51,7 +51,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
        }
    )
    cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
-    ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
+    ans = asyncio.run(
+        cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
    callback(0.8, "CV LLM respond: %s ..." % ans[:32])
    ans += "\n" + ans
    tokenize(doc, ans, eng)

View File

@@ -249,7 +249,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__":
    import sys
+
    def dummy(a, b):
        pass
+
    chunk(sys.argv[1], callback=dummy)

View File

@@ -116,10 +116,12 @@ class Pdf(PdfParser):
        last_index = -1
        last_box = {'text': ''}
        last_bull = None
+
        def sort_key(element):
            tbls_pn = element[1][0][0]
            tbls_top = element[1][0][3]
            return tbls_pn, tbls_top
+
        tbls.sort(key=sort_key)
        tbl_index = 0
        last_pn, last_bottom = 0, 0
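The `sort_key` helper above orders the extracted tables by page number first and vertical position second before they are interleaved with the Q&A text. The same idea in isolation, with invented data but the same tuple layout (`element[1][0]` is a `(page, left, right, top, bottom)` position):

```python
# Each entry: (payload, [(page_number, left, right, top, bottom)])
tbls = [
    ("table C", [(2, 0, 100, 50, 80)]),
    ("table A", [(1, 0, 100, 300, 340)]),
    ("table B", [(2, 0, 100, 10, 40)]),
]

def sort_key(element):
    tbls_pn = element[1][0][0]   # page number
    tbls_top = element[1][0][3]  # top coordinate on that page
    return tbls_pn, tbls_top

tbls.sort(key=sort_key)
print([name for name, _ in tbls])  # ['table A', 'table B', 'table C']
```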
@@ -140,21 +142,25 @@ class Pdf(PdfParser):
                    sum_tag = line_tag
                    sum_section = section
                    while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
-                            and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the middle of current answer
+                            and ((tbl_pn == line_pn and tbl_top <= line_top) or (
+                            tbl_pn < line_pn)): # add image at the middle of current answer
                        sum_tag = f'{tbl_tag}{sum_tag}'
                        sum_section = f'{tbl_text}{sum_section}'
                        tbl_index += 1
-                        tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
+                        tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
+                                                                                                                 tbl_index)
                    last_a = f'{last_a}{sum_section}'
                    last_tag = f'{last_tag}{sum_tag}'
                else:
                    if last_q:
                        while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
-                                and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the end of last answer
+                                and ((tbl_pn == line_pn and tbl_top <= line_top) or (
+                                tbl_pn < line_pn)): # add image at the end of last answer
                            last_tag = f'{last_tag}{tbl_tag}'
                            last_a = f'{last_a}{tbl_text}'
                            tbl_index += 1
-                            tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
+                            tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
+                                                                                                                     tbl_index)
                        image, poss = self.crop(last_tag, need_position=True)
                        qai_list.append((last_q, last_a, image, poss))
                        last_q, last_a, last_tag = '', '', ''
@@ -435,7 +441,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
                    if last_answer.strip():
                        sum_question = '\n'.join(question_stack)
                        if sum_question:
-                            res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
+                            res.append(beAdoc(deepcopy(doc), sum_question,
+                                              markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
                        last_answer = ''

                    i = question_level
@@ -447,7 +454,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
            if last_answer.strip():
                sum_question = '\n'.join(question_stack)
                if sum_question:
-                    res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
+                    res.append(beAdoc(deepcopy(doc), sum_question,
+                                      markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
            return res

    elif re.search(r"\.docx$", filename, re.IGNORECASE):
@@ -466,6 +474,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__":
    import sys
+
    def dummy(prog=None, msg=""):
        pass
+
    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@@ -64,7 +64,8 @@ def remote_call(filename, binary):
                del resume[k]

        resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
-                                                  "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
+                                                  "updated_at": datetime.datetime.now().strftime(
+                                                      "%Y-%m-%d %H:%M:%S")}]))
        resume = step_two.parse(resume)
        return resume
    except Exception:
@@ -171,6 +172,9 @@ def chunk(filename, binary=None, callback=None, **kwargs):
if __name__ == "__main__":
    import sys
+
    def dummy(a, b):
        pass
+
    chunk(sys.argv[1], callback=dummy)

View File

@@ -53,7 +53,8 @@ class Excel(ExcelParser):
            ws = wb[sheetname]
            images = Excel._extract_images_from_worksheet(ws, sheetname=sheetname)
            if images:
-                image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback, **kwargs)
+                image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback,
+                                                                              **kwargs)
                if image_descriptions and len(image_descriptions) == len(images):
                    for i, bf in enumerate(image_descriptions):
                        images[i]["image_description"] = "\n".join(bf[0][1])
@@ -121,7 +122,8 @@ class Excel(ExcelParser):
                    ]
                )
            )
-        callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+        callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (
+            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        return res, tables

    def _parse_headers(self, ws, rows):
@@ -315,7 +317,8 @@ def trans_bool(s):
def column_data_type(arr):
    arr = list(arr)
    counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
-    trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
+    trans = {t: f for f, t in
+             [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
    float_flag = False
    for a in arr:
        if a is None:
@@ -389,7 +392,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
                continue
            rows.append(row)

-        callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+        callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
+            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        dfs = [pd.DataFrame(np.array(rows), columns=headers)]

    elif re.search(r"\.csv$", filename, re.IGNORECASE):
@@ -445,7 +449,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
            df[clmns[j]] = cln
            if ty == "text":
                txts.extend([str(c) for c in cln if c])
-    clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))]
+    clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
+                 range(len(clmns))]

    eng = lang.lower() == "english" # is_english(txts)
    for ii, row in df.iterrows():
@@ -477,7 +482,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
if __name__ == "__main__":
    import sys
+
    def dummy(prog=None, msg=""):
        pass
+
    chunk(sys.argv[1], callback=dummy)

View File

@@ -152,6 +152,9 @@ def label_question(question, kbs):
if __name__ == "__main__":
    import sys
+
    def dummy(prog=None, msg=""):
        pass
+
    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@@ -263,7 +263,7 @@ class SparkTTS(Base):
                raise Exception(error)

    def on_close(self, ws, close_status_code, close_msg):
-        self.audio_queue.put(None)  # 放入 None 作为结束标志
+        self.audio_queue.put(None)  # None is terminator

    def on_open(self, ws):
        def run(*args):
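The translated comment documents a common producer/consumer idiom: the producer pushes `None` into the queue as an end-of-stream sentinel and the consumer stops when it sees it. A tiny standalone sketch of that idiom (unrelated to the actual WebSocket handlers):

```python
import queue
import threading

audio_queue = queue.Queue()

def producer():
    for part in (b"chunk-1", b"chunk-2"):
        audio_queue.put(part)
    audio_queue.put(None)  # None marks the end of the stream

def consumer():
    while True:
        item = audio_queue.get()
        if item is None:   # terminator reached, stop consuming
            break
        print("got", item)

t = threading.Thread(target=producer)
t.start()
consumer()
t.join()
```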

View File

@@ -658,7 +658,8 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
            if "content_ltks" in ck:
                ck["content_ltks"] = rag_tokenizer.tokenize(combined)
            if "content_sm_ltks" in ck:
-                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
+                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
+                    ck.get("content_ltks", rag_tokenizer.tokenize(combined)))

    if positioned_indices:
        chunks[:] = [chunks[i] for i in ordered_indices]
@@ -764,8 +765,8 @@ def not_title(txt):
        return True
    return re.search(r"[,;,。;!!]", txt)


def tree_merge(bull, sections, depth):
    if not sections or bull < 0:
        return sections
    if isinstance(sections[0], type("")):
@@ -787,6 +788,7 @@ def tree_merge(bull, sections, depth):
            return len(BULLET_PATTERN[bull]) + 1, text
        else:
            return len(BULLET_PATTERN[bull]) + 2, text
+
    level_set = set()
    lines = []
    for section in sections:
@@ -812,8 +814,8 @@ def tree_merge(bull, sections, depth):
    return [element for element in root.get_tree() if element]


def hierarchical_merge(bull, sections, depth):
    if not sections or bull < 0:
        return []
    if isinstance(sections[0], type("")):

View File

@@ -232,4 +232,5 @@ class FulltextQueryer(QueryBase):
            keywords.append(f"{tk}^{w}")
        return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
-                            {"minimum_should_match": min(3, len(keywords) / 10), "original_query": " ".join(origin_keywords)})
+                            {"minimum_should_match": min(3, len(keywords) / 10),
+                             "original_query": " ".join(origin_keywords)})

View File

@@ -66,7 +66,8 @@ class Dealer:
            if key in req and req[key] is not None:
                condition[field] = req[key]
        # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
-        for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
+        for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
+                    "removed_kwd"]:
            if key in req and req[key] is not None:
                condition[key] = req[key]
        return condition
@@ -141,7 +142,8 @@ class Dealer:
                matchText, _ = self.qryr.question(qst, min_match=0.1)
                matchDense.extra_options["similarity"] = 0.17
                res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
-                                            orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
+                                            orderBy, offset, limit, idx_names, kb_ids,
+                                            rank_feature=rank_feature)
                total = self.dataStore.get_total(res)
                logging.debug("Dealer.search 2 TOTAL: {}".format(total))
@@ -219,7 +221,8 @@ class Dealer:
        for i in range(len(chunk_v)):
            if len(ans_v[0]) != len(chunk_v[i]):
                chunk_v[i] = [0.0] * len(ans_v[0])
-                logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
+                logging.warning(
+                    "The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))

        assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
            len(ans_v[0]), len(chunk_v[0]))
@@ -395,7 +398,8 @@ class Dealer:
        if isinstance(tenant_ids, str):
            tenant_ids = tenant_ids.split(",")

-        sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
+        sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
+                           rank_feature=rank_feature)

        if rerank_mdl and sres.total > 0:
            sim, tsim, vsim = self.rerank_by_model(
@@ -558,7 +562,8 @@ class Dealer:
    def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
        idx_nm = index_name(tenant_id)
-        match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
+        match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
+                                        keywords_topn)
        res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
        aggs = self.dataStore.get_aggregation(res, "tag_kwd")
        if not aggs:
@@ -596,7 +601,8 @@ class Dealer:
                doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
        doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
        kb_ids = [doc_id2kb_id[doc_id]]
-        es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
+        es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
+                                       OrderByExpr(), 0, 128, idx_nms,
                                       kb_ids)
        toc = []
        dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])

View File

@ -116,7 +116,9 @@ m = set(["赵","钱","孙","李",
"", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "",
"","","","西","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","鹿","", "", "", "", "西", "", "", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "鹿", "",
"万俟", "司马", "上官", "欧阳", "万俟", "司马", "上官", "欧阳",
"夏侯", "诸葛", "闻人", "东方", "夏侯", "诸葛", "闻人", "东方",
"赫连", "皇甫", "尉迟", "公羊", "赫连", "皇甫", "尉迟", "公羊",
@ -138,5 +140,5 @@ m = set(["赵","钱","孙","李",
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "",
"第五", "", ""]) "第五", "", ""])
def isit(n):return n.strip() in m
def isit(n): return n.strip() in m

View File

@@ -114,7 +114,8 @@ class Dealer:
        return res

    def token_merge(self, tks):
-        def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
+        def one_term(t):
+            return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

        res, i = [], 0
        while i < len(tks):
@@ -220,7 +221,8 @@ class Dealer:
            return 3

-        def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
+        def idf(s, N):
+            return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

        tw = []
        if not preprocess:
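The reformatted `idf` helper is a smoothed inverse document frequency: `log10(10 + (N - s + 0.5) / (s + 0.5))`, where `s` is how many documents contain the term and `N` is the corpus size, so rare terms get larger weights and the value never reaches zero. A standalone sketch of the same formula:

```python
import math

def idf(s, N):
    # Smoothed IDF as in the diff: 0.5 smoothing on the ratio plus a +10 floor inside the log.
    return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

# A term found in almost every document weighs less than a rare one.
print(idf(990, 1000) < idf(3, 1000))  # True
```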

View File

@@ -28,17 +28,16 @@ from rag.prompts.template import load_prompt
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string

-
STOP_TOKEN = "<|STOP|>"
COMPLETE_TASK = "complete_task"
INPUT_UTILIZATION = 0.5


def get_value(d, k1, k2):
    return d.get(k1, d.get(k2))


def chunks_format(reference):
    return [
        {
            "id": get_value(chunk, "chunk_id", "id"),
@@ -258,9 +257,11 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]):
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
-    rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
+    rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query,
+                                                                                                     languages=languages)
-    ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
+    ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}],
+                                    {"temperature": 0.2})
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if ans.find("**ERROR**") >= 0:
        return query
@@ -332,7 +333,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
                "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
                "parameters": {
                    "type": "object",
-                    "properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
+                    "properties": {
+                        "answer": {"type": "string", "description": "The final answer to the user's question"}},
                    "required": ["answer"]
                }
            }
@@ -341,7 +343,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
        name = tool["function"]["name"]
        desc[name] = tool

-    return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
+    return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in
+                        enumerate(desc.items())])


def form_history(history, limit=-6):
@@ -356,8 +359,8 @@ def form_history(history, limit=-6):
    return context


-async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict],
+                             user_defined_prompts: dict = {}):
    tools_desc = tool_schema(tools_description)

    context = ""
@@ -375,7 +378,8 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
    return kwd


-async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+async def next_step_async(chat_mdl, history: list, tools_description: list[dict], task_desc,
+                          user_defined_prompts: dict = {}):
    if not tools_description:
        return "", 0
    desc = tool_schema(tools_description)
@@ -438,9 +442,11 @@ async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)


-async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str],
+                              user_defined_prompts: dict = {}):
    template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
-    system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
+    system_prompt = template.render(goal=goal, sub_goal=sub_goal,
+                                    results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
    user_prompt = " → rank: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
@@ -488,10 +494,13 @@ async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None
TOC_DETECTION = load_prompt("toc_detection")

+
async def detect_table_of_contents(page_1024: list[str], chat_mdl):
    toc_secs = []
    for i, sec in enumerate(page_1024[:22]):
-        ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
+        ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.",
+                             chat_mdl)
        if toc_secs and not ans["exists"]:
            break
        toc_secs.append(sec)
@@ -500,11 +509,14 @@ async def detect_table_of_contents(page_1024:list[str], chat_mdl):
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")

+
async def extract_table_of_contents(toc_pages, chat_mdl):
    if not toc_pages:
        return []

-    return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
+    return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)),
+                          "Only JSON please.", chat_mdl)


async def toc_index_extractor(toc: list[dict], content: str, chat_mdl):
@@ -529,11 +541,14 @@ async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
    If the title of the section are not in the provided pages, do not add the physical_index to it.
    Directly return the final JSON structure. Do not output anything else."""

-    prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
+    prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False,
+                                                                          indent=2) + '\nDocument pages:\n' + content
    return await gen_json(prompt, "Only JSON please.", chat_mdl)


TOC_INDEX = load_prompt("toc_index")

+
async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
    if not toc_arr or not sections:
        return []
@@ -558,6 +573,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
                toc_arr[j]["indices"].append(i)

    all_pathes = []
+
    def dfs(start, path):
        nonlocal all_pathes
        if start >= len(toc_arr):
@@ -656,11 +672,15 @@ async def toc_transformer(toc_pages, chat_mdl):
    toc_content = "\n".join(toc_pages)
    prompt = init_prompt + '\n Given table of contents\n:' + toc_content

+
    def clean_toc(arr):
        for a in arr:
            a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])

    last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
-    if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+    if_complete = await check_if_toc_transformation_is_complete(toc_content,
+                                                                json.dumps(last_complete, ensure_ascii=False, indent=2),
+                                                                chat_mdl)
    clean_toc(last_complete)
    if if_complete == "yes":
        return last_complete
@@ -682,12 +702,16 @@ async def toc_transformer(toc_pages, chat_mdl):
            break
        clean_toc(new_complete)
        last_complete.extend(new_complete)
-        if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+        if_complete = await check_if_toc_transformation_is_complete(toc_content,
+                                                                    json.dumps(last_complete, ensure_ascii=False,
+                                                                               indent=2), chat_mdl)

    return last_complete


TOC_LEVELS = load_prompt("assign_toc_levels")

+
async def assign_toc_levels(toc_secs, chat_mdl, gen_conf={"temperature": 0.2}):
    if not toc_secs:
        return []
@@ -701,12 +725,15 @@ async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2})
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")


# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
    try:
        ans = await gen_json(
            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
-            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
+            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(
+                text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
            chat_mdl,
            gen_conf={"temperature": 0.0, "top_p": 0.9}
        )
@@ -812,12 +839,15 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")


async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
    import numpy as np
    try:
        ans = await gen_json(
            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
-            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
+            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join(
+                [json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
            chat_mdl,
            gen_conf={"temperature": 0.0, "top_p": 0.9}
        )
@@ -838,6 +868,8 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
META_DATA = load_prompt("meta_data")

+
+
async def gen_metadata(chat_mdl, schema: dict, content: str):
    template = PROMPT_JINJA_ENV.from_string(META_DATA)
    for k, desc in schema["properties"].items():
View File

@@ -1,6 +1,5 @@
import os

-
PROMPT_DIR = os.path.dirname(__file__)

_loaded_prompts = {}

View File

@@ -48,9 +48,11 @@ def main():
                    REDIS_CONN.transaction(key, file_bin, 12 * 60)
                    logging.info("CACHE: {}".format(loc))
                except Exception as e:
-                    traceback.print_stack(e)
+                    logging.error(f"Error to get data from REDIS: {e}")
+                    traceback.print_stack()
        except Exception as e:
-            traceback.print_stack(e)
+            logging.error(f"Error to check REDIS connection: {e}")
+            traceback.print_stack()


if __name__ == "__main__":
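The hunk above fixes a misuse of `traceback.print_stack(e)`: `print_stack()` takes an optional frame, not an exception, so passing the exception object would itself fail inside the error handler. The diff logs the message and then dumps the current call stack; if the exception's own traceback is wanted instead, `logging.exception` or `traceback.print_exc()` are the usual choices. A hedged sketch of the options:

```python
import logging
import traceback

logging.basicConfig(level=logging.DEBUG)

try:
    1 / 0
except Exception as e:
    # Pattern used in the diff: log the message, then print the current call stack.
    logging.error(f"Error to get data from REDIS: {e}")
    traceback.print_stack()

    # Alternatives that capture the exception's own traceback:
    logging.exception("Error while talking to Redis")
    traceback.print_exc()
```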

View File

@@ -29,7 +29,6 @@ JSON_DATA = {
DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx"  # Get DISCORD_BOT_KEY from Discord Application

intents = discord.Intents.default()
intents.message_content = True
-
client = discord.Client(intents=intents)

View File

@ -38,7 +38,8 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings from common import settings
from common.config_utils import show_configs from common.config_utils import show_configs
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, \
MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
from common.constants import FileSource, TaskStatus from common.constants import FileSource, TaskStatus
from common.data_source.config import INDEX_BATCH_SIZE from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector from common.data_source.confluence_connector import ConfluenceConnector
@ -161,6 +162,7 @@ class SyncBase:
def _get_source_prefix(self): def _get_source_prefix(self):
return "" return ""
class _BlobLikeBase(SyncBase): class _BlobLikeBase(SyncBase):
DEFAULT_BUCKET_TYPE: str = "s3" DEFAULT_BUCKET_TYPE: str = "s3"
@ -199,22 +201,27 @@ class _BlobLikeBase(SyncBase):
) )
return document_batch_generator return document_batch_generator
class S3(_BlobLikeBase): class S3(_BlobLikeBase):
SOURCE_NAME: str = FileSource.S3 SOURCE_NAME: str = FileSource.S3
DEFAULT_BUCKET_TYPE: str = "s3" DEFAULT_BUCKET_TYPE: str = "s3"
class R2(_BlobLikeBase): class R2(_BlobLikeBase):
SOURCE_NAME: str = FileSource.R2 SOURCE_NAME: str = FileSource.R2
DEFAULT_BUCKET_TYPE: str = "r2" DEFAULT_BUCKET_TYPE: str = "r2"
class OCI_STORAGE(_BlobLikeBase): class OCI_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.OCI_STORAGE SOURCE_NAME: str = FileSource.OCI_STORAGE
DEFAULT_BUCKET_TYPE: str = "oci_storage" DEFAULT_BUCKET_TYPE: str = "oci_storage"
class GOOGLE_CLOUD_STORAGE(_BlobLikeBase): class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.GOOGLE_CLOUD_STORAGE SOURCE_NAME: str = FileSource.GOOGLE_CLOUD_STORAGE
DEFAULT_BUCKET_TYPE: str = "google_cloud_storage" DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"
class Confluence(SyncBase): class Confluence(SyncBase):
SOURCE_NAME: str = FileSource.CONFLUENCE SOURCE_NAME: str = FileSource.CONFLUENCE
@ -248,7 +255,9 @@ class Confluence(SyncBase):
index_recursively=index_recursively, index_recursively=index_recursively,
) )
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"]) credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"],
connector_name=DocumentSource.CONFLUENCE,
credential_json=self.conf["credentials"])
self.connector.set_credentials_provider(credentials_provider) self.connector.set_credentials_provider(credentials_provider)
# Determine the time range for synchronization based on reindex or poll_range_start # Determine the time range for synchronization based on reindex or poll_range_start
@ -280,7 +289,8 @@ class Confluence(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint)) doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator: for document, failure, next_checkpoint in doc_generator:
if failure is not None: if failure is not None:
logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure)) logging.warning("Confluence connector failure: %s",
getattr(failure, "failure_message", failure))
continue continue
if document is not None: if document is not None:
pending_docs.append(document) pending_docs.append(document)
@ -314,10 +324,12 @@ class Notion(SyncBase):
document_generator = ( document_generator = (
self.connector.load_from_state() self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"] if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp()) else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
) )
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"]) begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info)) logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
return document_generator return document_generator
@ -340,10 +352,12 @@ class Discord(SyncBase):
document_generator = ( document_generator = (
self.connector.load_from_state() self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"] if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp()) else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
) )
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"]) begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info)) logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
return document_generator return document_generator
@ -485,7 +499,8 @@ class GoogleDrive(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint)) doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator: for document, failure, next_checkpoint in doc_generator:
if failure is not None: if failure is not None:
logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure)) logging.warning("Google Drive connector failure: %s",
getattr(failure, "failure_message", failure))
continue continue
if document is not None: if document is not None:
pending_docs.append(document) pending_docs.append(document)
@ -667,6 +682,7 @@ class WebDAV(SyncBase):
)) ))
return document_batch_generator return document_batch_generator
class Moodle(SyncBase): class Moodle(SyncBase):
SOURCE_NAME: str = FileSource.MOODLE SOURCE_NAME: str = FileSource.MOODLE
@ -739,6 +755,7 @@ class BOX(SyncBase):
logging.info("Connect to Box: folder_id({}) {}".format(self.conf["folder_id"], begin_info)) logging.info("Connect to Box: folder_id({}) {}".format(self.conf["folder_id"], begin_info))
return document_generator return document_generator
class Airtable(SyncBase): class Airtable(SyncBase):
SOURCE_NAME: str = FileSource.AIRTABLE SOURCE_NAME: str = FileSource.AIRTABLE
@ -784,6 +801,7 @@ class Airtable(SyncBase):
return document_generator return document_generator
func_factory = { func_factory = {
FileSource.S3: S3, FileSource.S3: S3,
FileSource.R2: R2, FileSource.R2: R2,
View File
@ -283,7 +283,8 @@ async def build_chunks(task, progress_callback):
try: try:
d = copy.deepcopy(document) d = copy.deepcopy(document)
d.update(chunk) d.update(chunk)
d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() d["id"] = xxhash.xxh64(
(chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp() d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"): if not d.get("image"):
@ -328,9 +329,11 @@ async def build_chunks(task, progress_callback):
d["important_kwd"] = cached.split(",") d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return return
tasks = [] tasks = []
for d in docs: for d in docs:
tasks.append(
    asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try: try:
await asyncio.gather(*tasks, return_exceptions=False) await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e: except Exception as e:
@ -355,9 +358,11 @@ async def build_chunks(task, progress_callback):
if cached: if cached:
d["question_kwd"] = cached.split("\n") d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
tasks = [] tasks = []
for d in docs: for d in docs:
tasks.append(
    asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try: try:
await asyncio.gather(*tasks, return_exceptions=False) await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e: except Exception as e:
@ -374,15 +379,18 @@ async def build_chunks(task, progress_callback):
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"]) chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
async def gen_metadata_task(chat_mdl, d): async def gen_metadata_task(chat_mdl, d):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
                       task["parser_config"]["metadata"])
if not cached: if not cached:
async with chat_limiter: async with chat_limiter:
cached = await gen_metadata(chat_mdl, cached = await gen_metadata(chat_mdl,
metadata_schema(task["parser_config"]["metadata"]), metadata_schema(task["parser_config"]["metadata"]),
d["content_with_weight"]) d["content_with_weight"])
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
              task["parser_config"]["metadata"])
if cached: if cached:
d["metadata_obj"] = cached d["metadata_obj"] = cached
tasks = [] tasks = []
for d in docs: for d in docs:
tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d))) tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
@ -430,7 +438,8 @@ async def build_chunks(task, progress_callback):
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return None return None
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
        d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]}) examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else: else:
docs_to_tag.append(d) docs_to_tag.append(d)
@ -454,6 +463,7 @@ async def build_chunks(task, progress_callback):
if cached: if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached) d[TAG_FLD] = json.loads(cached)
tasks = [] tasks = []
for d in docs_to_tag: for d in docs_to_tag:
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags))) tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
@ -477,7 +487,8 @@ def build_TOC(task, docs, progress_callback):
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
)) ))
toc: list[dict] = asyncio.run(
    run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' ')) logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0 ii = 0
while ii < len(toc): while ii < len(toc):
@ -499,7 +510,8 @@ def build_TOC(task, docs, progress_callback):
d["toc_kwd"] = "toc" d["toc_kwd"] = "toc"
d["available_int"] = 0 d["available_int"] = 0
d["page_num_int"] = [100000000] d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() d["id"] = xxhash.xxh64(
(d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d return d
return None return None
@ -588,7 +600,8 @@ async def run_dataflow(task: dict):
return return
if not chunks: if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
                                   task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0) embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
@ -610,10 +623,12 @@ async def run_dataflow(task: dict):
e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id) embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@timeout(60) @timeout(60)
def batch_encode(txts): def batch_encode(txts):
nonlocal embedding_model nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts]) return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([]) vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks] texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20 / (len(texts) // settings.EMBEDDING_BATCH_SIZE + 1) delta = 0.20 / (len(texts) // settings.EMBEDDING_BATCH_SIZE + 1)
@ -636,10 +651,10 @@ async def run_dataflow(task: dict):
ck["q_%d_vec" % len(v)] = v ck["q_%d_vec" % len(v)] = v
except Exception as e: except Exception as e:
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}") set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
                                   task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return return
metadata = {} metadata = {}
for ck in chunks: for ck in chunks:
ck["doc_id"] = doc_id ck["doc_id"] = doc_id
@ -686,15 +701,19 @@ async def run_dataflow(task: dict):
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...") set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000)) e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e: if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
                                   task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return return
time_cost = timer() - start_ts time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)) set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
                                    task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
                                                                    task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
                                   dsl=str(pipeline))
@timeout(3600) @timeout(3600)
@ -792,19 +811,22 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mom_ck["available_int"] = 0 mom_ck["available_int"] = 0
flds = list(mom_ck.keys()) flds = list(mom_ck.keys())
for fld in flds: for fld in flds:
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int"]: if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
"position_int"]:
del mom_ck[fld] del mom_ck[fld]
mothers.append(mom_ck) mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE): for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
                        search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id) task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return False return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE): for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
                                           search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id) task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
@ -821,7 +843,8 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str) TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist: except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
                                           search.index_name(task_tenant_id), task_dataset_id, )
tasks = [] tasks = []
for chunk_id in chunk_ids: for chunk_id in chunk_ids:
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id))) tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
@ -972,7 +995,6 @@ async def do_handle_task(task):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration") progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return return
graphrag_conf = kb_parser_config.get("graphrag", {}) graphrag_conf = kb_parser_config.get("graphrag", {})
start_ts = timer() start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
@ -1084,8 +1106,8 @@ async def do_handle_task(task):
f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled." f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled."
) )
async def handle_task(): async def handle_task():
global DONE_TASKS, FAILED_TASKS global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect() redis_msg, task = await collect()
if not task: if not task:
@ -1093,7 +1115,8 @@ async def handle_task():
return return
task_type = task["task_type"] task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
                                                         PipelineTaskType.PARSE) or PipelineTaskType.PARSE
try: try:
logging.info(f"handle_task begin for task {json.dumps(task)}") logging.info(f"handle_task begin for task {json.dumps(task)}")
@ -1119,7 +1142,9 @@ async def handle_task():
if task_type in ["graphrag", "raptor", "mindmap"]: if task_type in ["graphrag", "raptor", "mindmap"]:
task_document_ids = task["doc_ids"] task_document_ids = task["doc_ids"]
if not task.get("dataflow_id", ""): if not task.get("dataflow_id", ""):
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
                                                      task_type=pipeline_task_type,
                                                      fake_document_ids=task_document_ids)
redis_msg.ack() redis_msg.ack()
@ -1249,6 +1274,7 @@ async def main():
await asyncio.gather(report_task, return_exceptions=True) await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!") logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__": if __name__ == "__main__":
faulthandler.enable() faulthandler.enable()
init_root_logger(CONSUMER_NAME) init_root_logger(CONSUMER_NAME)

View File
pass pass
try: try:
credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id,
                                     client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name,
                             credential=credentials)
except Exception: except Exception:
logging.exception("Fail to connect %s" % self.account_url) logging.exception("Fail to connect %s" % self.account_url)
View File
@ -25,6 +25,7 @@ from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg==" test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64) test_image = base64.b64decode(test_image_base64)
async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str = "imagetemps"): async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str = "imagetemps"):
import logging import logging
from io import BytesIO from io import BytesIO
View File
@ -16,6 +16,8 @@
import logging import logging
from common.crypto_utils import CryptoUtil from common.crypto_utils import CryptoUtil
# from common.decorator import singleton # from common.decorator import singleton
class EncryptedStorageWrapper: class EncryptedStorageWrapper:
@ -240,6 +242,7 @@ class EncryptedStorageWrapper:
self.encryption_enabled = False self.encryption_enabled = False
logging.info("Encryption disabled") logging.info("Encryption disabled")
# Create singleton wrapper function # Create singleton wrapper function
def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_enabled=True): def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_enabled=True):
""" """
View File
@ -32,7 +32,6 @@ ATTEMPT_TIME = 2
@singleton @singleton
class ESConnection(ESConnectionBase): class ESConnection(ESConnectionBase):
""" """
CRUD operations CRUD operations
""" """
@ -82,7 +81,8 @@ class ESConnection(ESConnectionBase):
vector_similarity_weight = 0.5 vector_similarity_weight = 0.5
for m in match_expressions: for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
    match_expressions[1],
    MatchDenseExpr) and isinstance(
    match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"] weights = m.fusion_params["weights"]
@ -220,13 +220,15 @@ class ESConnection(ESConnectionBase):
try: try:
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");") self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
except Exception: except Exception:
self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") self.logger.exception(
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try: try:
self.es.update(index=index_name, id=chunk_id, doc=doc) self.es.update(index=index_name, id=chunk_id, doc=doc)
return True return True
except Exception as e: except Exception as e:
self.logger.exception( self.logger.exception(
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e)) f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(
e))
break break
return False return False
View File
@ -25,18 +25,23 @@ import PyPDF2
from docx import Document from docx import Document
import olefile import olefile
def _is_zip(h: bytes) -> bool: def _is_zip(h: bytes) -> bool:
return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08") return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08")
def _is_pdf(h: bytes) -> bool: def _is_pdf(h: bytes) -> bool:
return h.startswith(b"%PDF-") return h.startswith(b"%PDF-")
def _is_ole(h: bytes) -> bool: def _is_ole(h: bytes) -> bool:
return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1") return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1")
def _sha10(b: bytes) -> str: def _sha10(b: bytes) -> str:
return hashlib.sha256(b).hexdigest()[:10] return hashlib.sha256(b).hexdigest()[:10]
def _guess_ext(b: bytes) -> str: def _guess_ext(b: bytes) -> str:
h = b[:8] h = b[:8]
if _is_zip(h): if _is_zip(h):
@ -58,6 +63,7 @@ def _guess_ext(b: bytes) -> str:
return ".doc" return ".doc"
return ".bin" return ".bin"
# Try to extract the real embedded payload from OLE's Ole10Native # Try to extract the real embedded payload from OLE's Ole10Native
def _extract_ole10native_payload(data: bytes) -> bytes: def _extract_ole10native_payload(data: bytes) -> bytes:
try: try:
@ -82,6 +88,7 @@ def _extract_ole10native_payload(data: bytes) -> bytes:
pass pass
return data return data
def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]: def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]:
""" """
Only extract the 'first layer' of embedding, returning raw (filename, bytes). Only extract the 'first layer' of embedding, returning raw (filename, bytes).
@ -198,6 +205,8 @@ def extract_links_from_pdf(pdf_bytes: bytes):
_GLOBAL_SESSION: Optional[requests.Session] = None _GLOBAL_SESSION: Optional[requests.Session] = None
def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session: def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
"""Get or create a global reusable session.""" """Get or create a global reusable session."""
global _GLOBAL_SESSION global _GLOBAL_SESSION
View File
@ -28,7 +28,6 @@ from common.doc_store.infinity_conn_base import InfinityConnectionBase
@singleton @singleton
class InfinityConnection(InfinityConnectionBase): class InfinityConnection(InfinityConnectionBase):
""" """
Dataframe and fields convert Dataframe and fields convert
""" """
@ -83,7 +82,6 @@ class InfinityConnection(InfinityConnectionBase):
tokens[0] = field tokens[0] = field
return "^".join(tokens) return "^".join(tokens)
""" """
CRUD operations CRUD operations
""" """
@ -159,7 +157,8 @@ class InfinityConnection(InfinityConnectionBase):
if table_found: if table_found:
break break
if not table_found: if not table_found:
self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}") self.logger.error(
f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
return pd.DataFrame(), 0 return pd.DataFrame(), 0
for matchExpr in match_expressions: for matchExpr in match_expressions:
@ -280,7 +279,8 @@ class InfinityConnection(InfinityConnectionBase):
try: try:
table_instance = db_instance.get_table(table_name) table_instance = db_instance.get_table(table_name)
except Exception: except Exception:
self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.") self.logger.warning(
f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
continue continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df() kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}") self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
@ -288,7 +288,9 @@ class InfinityConnection(InfinityConnectionBase):
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, ["id"]) res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist()) fields = set(res.columns.tolist())
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]: for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd",
"question_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks",
"authors_sm_tks"]:
fields.add(field) fields.add(field)
res_fields = self.get_fields(res, list(fields)) res_fields = self.get_fields(res, list(fields))
return res_fields.get(chunk_id, None) return res_fields.get(chunk_id, None)
@ -379,7 +381,9 @@ class InfinityConnection(InfinityConnectionBase):
d[k] = "_".join(f"{num:08x}" for num in v) d[k] = "_".join(f"{num:08x}" for num in v)
else: else:
d[k] = v d[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]: for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
"content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd",
"question_tks"]:
if k in d: if k in d:
del d[k] del d[k]
@ -478,7 +482,8 @@ class InfinityConnection(InfinityConnectionBase):
del new_value[k] del new_value[k]
else: else:
new_value[k] = v new_value[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]: for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
"content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in new_value: if k in new_value:
del new_value[k] del new_value[k]
@ -502,7 +507,8 @@ class InfinityConnection(InfinityConnectionBase):
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.") self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
for update_kv, ids in remove_opt.items(): for update_kv, ids in remove_opt.items():
k, v = json.loads(update_kv) k, v = json.loads(update_kv)
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])),
                      {k: "###".join(v)})
table_instance.update(filter, new_value) table_instance.update(filter, new_value)
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
View File
@ -46,6 +46,7 @@ class RAGFlowMinio:
# pass original identifier forward for use by other decorators # pass original identifier forward for use by other decorators
kwargs['_orig_bucket'] = original_bucket kwargs['_orig_bucket'] = original_bucket
return method(self, actual_bucket, *args, **kwargs) return method(self, actual_bucket, *args, **kwargs)
return wrapper return wrapper
@staticmethod @staticmethod
@ -71,6 +72,7 @@ class RAGFlowMinio:
fnm = f"{orig_bucket}/{fnm}" fnm = f"{orig_bucket}/{fnm}"
return method(self, bucket, fnm, *args, **kwargs) return method(self, bucket, fnm, *args, **kwargs)
return wrapper return wrapper
def __open__(self): def __open__(self):
View File
@ -37,7 +37,8 @@ from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton from common.decorator import singleton
from common.float_utils import get_float from common.float_utils import get_float
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
    MatchDenseExpr
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
ATTEMPT_TIME = 2 ATTEMPT_TIME = 2
View File
@ -6,7 +6,6 @@ from urllib.parse import quote_plus
from common.config_utils import get_base_config from common.config_utils import get_base_config
from common.decorator import singleton from common.decorator import singleton
CREATE_TABLE_SQL = """ CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS `{}` ( CREATE TABLE IF NOT EXISTS `{}` (
`key` VARCHAR(255) PRIMARY KEY, `key` VARCHAR(255) PRIMARY KEY,
@ -36,7 +35,8 @@ def get_opendal_config():
"table": opendal_config.get("config", {}).get("oss_table", "opendal_storage"), "table": opendal_config.get("config", {}).get("oss_table", "opendal_storage"),
"max_allowed_packet": str(max_packet) "max_allowed_packet": str(max_packet)
} }
kwargs["connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}" kwargs[
"connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
else: else:
scheme = opendal_config.get("scheme") scheme = opendal_config.get("scheme")
config_data = opendal_config.get("config", {}) config_data = opendal_config.get("config", {})
@ -99,7 +99,6 @@ class OpenDALStorage:
def obj_exist(self, bucket, fnm, tenant_id=None): def obj_exist(self, bucket, fnm, tenant_id=None):
return self._operator.exists(f"{bucket}/{fnm}") return self._operator.exists(f"{bucket}/{fnm}")
def init_db_config(self): def init_db_config(self):
try: try:
conn = pymysql.connect( conn = pymysql.connect(
View File
@ -26,7 +26,8 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout from opensearchpy import ConnectionTimeout
from common.decorator import singleton from common.decorator import singleton
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
    FusionExpr
from rag.nlp import is_english, rag_tokenizer from rag.nlp import is_english, rag_tokenizer
from common.constants import PAGERANK_FLD, TAG_FLD from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings from common import settings
View File
@ -42,6 +42,7 @@ class RAGFlowOSS:
# If there is a default bucket, use the default bucket # If there is a default bucket, use the default bucket
actual_bucket = self.bucket if self.bucket else bucket actual_bucket = self.bucket if self.bucket else bucket
return method(self, actual_bucket, *args, **kwargs) return method(self, actual_bucket, *args, **kwargs)
return wrapper return wrapper
@staticmethod @staticmethod
@ -50,6 +51,7 @@ class RAGFlowOSS:
# If the prefix path is set, use the prefix path # If the prefix path is set, use the prefix path
fnm = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm fnm = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm
return method(self, bucket, fnm, *args, **kwargs) return method(self, bucket, fnm, *args, **kwargs)
return wrapper return wrapper
def __open__(self): def __open__(self):
@ -171,4 +173,3 @@ class RAGFlowOSS:
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
return None return None
View File
@ -21,7 +21,6 @@ Utility functions for Raptor processing decisions.
import logging import logging
from typing import Optional from typing import Optional
# File extensions for structured data types # File extensions for structured data types
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"} EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
CSV_EXTENSIONS = {".csv", ".tsv"} CSV_EXTENSIONS = {".csv", ".tsv"}
View File
@ -33,6 +33,7 @@ except Exception:
except Exception: except Exception:
REDIS = {} REDIS = {}
class RedisMsg: class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message): def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer self.__consumer = consumer
@ -278,7 +279,8 @@ class RedisDB:
def decrby(self, key: str, decrement: int): def decrby(self, key: str, decrement: int):
return self.REDIS.decrby(key, decrement) return self.REDIS.decrby(key, decrement)
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default",
                               increment: int = 1, ensure_minimum: int | None = None) -> int:
redis_key = f"{key_prefix}:{namespace}" redis_key = f"{key_prefix}:{namespace}"
try: try:
View File
@ -46,6 +46,7 @@ class RAGFlowS3:
# If there is a default bucket, use the default bucket # If there is a default bucket, use the default bucket
actual_bucket = self.bucket if self.bucket else bucket actual_bucket = self.bucket if self.bucket else bucket
return method(self, actual_bucket, *args, **kwargs) return method(self, actual_bucket, *args, **kwargs)
return wrapper return wrapper
@staticmethod @staticmethod
@ -57,6 +58,7 @@ class RAGFlowS3:
if self.prefix_path: if self.prefix_path:
fnm = f"{self.prefix_path}/{bucket}/{fnm}" fnm = f"{self.prefix_path}/{bucket}/{fnm}"
return method(self, bucket, fnm, *args, **kwargs) return method(self, bucket, fnm, *args, **kwargs)
return wrapper return wrapper
def __open__(self): def __open__(self):
View File
@ -30,7 +30,8 @@ class Tavily:
search_depth="advanced", search_depth="advanced",
max_results=6 max_results=6
) )
return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res in response["results"]] return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res
in response["results"]]
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)