Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-25 08:06:48 +08:00)
Fix: Merge main branch (#10377)
### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: jinhai <haijin.chn@gmail.com>
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Lynn <lynn_inf@hotmail.com>
Co-authored-by: chanx <1243304602@qq.com>
Co-authored-by: balibabu <cike8899@users.noreply.github.com>
Co-authored-by: 纷繁下的无奈 <zhileihuang@126.com>
Co-authored-by: huangzl <huangzl@shinemo.com>
Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com>
Co-authored-by: Wilmer <33392318@qq.com>
Co-authored-by: Adrian Weidig <adrianweidig@gmx.net>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yongteng Lei <yongtengrey@outlook.com>
Co-authored-by: Liu An <asiro@qq.com>
Co-authored-by: buua436 <66937541+buua436@users.noreply.github.com>
Co-authored-by: BadwomanCraZY <511528396@qq.com>
Co-authored-by: cucusenok <31804608+cucusenok@users.noreply.github.com>
Co-authored-by: Russell Valentine <russ@coldstonelabs.org>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Billy Bao <newyorkupperbay@gmail.com>
Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: TensorNull <129579691+TensorNull@users.noreply.github.com>
Co-authored-by: TensorNull <tensor.null@gmail.com>
Co-authored-by: Ajay <160579663+aybanda@users.noreply.github.com>
Co-authored-by: AB <aj@Ajays-MacBook-Air.local>
Co-authored-by: 天海蒼灆 <huangaoqin@tecpie.com>
Co-authored-by: He Wang <wanghechn@qq.com>
Co-authored-by: Atsushi Hatakeyama <atu729@icloud.com>
Co-authored-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Mohamed Mathari <155896313+melmathari@users.noreply.github.com>
Co-authored-by: Mohamed Mathari <nocodeventure@Mac-mini-van-Mohamed.fritz.box>
Co-authored-by: Stephen Hu <stephenhu@seismic.com>
Co-authored-by: Shaun Zhang <zhangwfjh@users.noreply.github.com>
Co-authored-by: zhimeng123 <60221886+zhimeng123@users.noreply.github.com>
Co-authored-by: mxc <mxc@example.com>
Co-authored-by: Dominik Novotný <50611433+SgtMarmite@users.noreply.github.com>
Co-authored-by: EVGENY M <168018528+rjohny55@users.noreply.github.com>
Co-authored-by: mcoder6425 <mcoder64@gmail.com>
Co-authored-by: TeslaZY <TeslaZY@outlook.com>
Co-authored-by: lemsn <lemsn@msn.com>
Co-authored-by: lemsn <lemsn@126.com>
Co-authored-by: Adrian Gora <47756404+adagora@users.noreply.github.com>
Co-authored-by: Womsxd <45663319+Womsxd@users.noreply.github.com>
Co-authored-by: FatMii <39074672+FatMii@users.noreply.github.com>
@@ -22,12 +22,15 @@ from docx import Document
 
 from api.db import ParserType
 from deepdoc.parser.utils import get_text
-from rag.nlp import bullets_category, remove_contents_table, hierarchical_merge, \
-    make_colon_as_title, tokenize_chunks, docx_question_level
-from rag.nlp import rag_tokenizer
+from rag.nlp import bullets_category, remove_contents_table, \
+    make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge
+from rag.nlp import rag_tokenizer, Node
 from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
 
 
 class Docx(DocxParser):
     def __init__(self):
         pass
@@ -55,49 +58,37 @@ class Docx(DocxParser):
         return [line for line in lines if line]
 
     def __call__(self, filename, binary=None, from_page=0, to_page=100000):
-        self.doc = Document(
-            filename) if not binary else Document(BytesIO(binary))
-        pn = 0
-        lines = []
-        bull = bullets_category([p.text for p in self.doc.paragraphs])
-        for p in self.doc.paragraphs:
-            if pn > to_page:
-                break
-            question_level, p_text = docx_question_level(p, bull)
-            if not p_text.strip("\n"):
-                continue
-            lines.append((question_level, p_text))
-
-            for run in p.runs:
-                if 'lastRenderedPageBreak' in run._element.xml:
-                    pn += 1
-                    continue
-                if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
-                    pn += 1
-
-        visit = [False for _ in range(len(lines))]
-        sections = []
-        for s in range(len(lines)):
-            e = s + 1
-            while e < len(lines):
-                if lines[e][0] <= lines[s][0]:
-                    break
-                e += 1
-            if e - s == 1 and visit[s]:
-                continue
-            sec = []
-            next_level = lines[s][0] + 1
-            while not sec and next_level < 22:
-                for i in range(s+1, e):
-                    if lines[i][0] != next_level:
-                        continue
-                    sec.append(lines[i][1])
-                    visit[i] = True
-                next_level += 1
-            sec.insert(0, lines[s][1])
-
-            sections.append("\n".join(sec))
-        return [s for s in sections if s]
+        self.doc = Document(
+            filename) if not binary else Document(BytesIO(binary))
+        pn = 0
+        lines = []
+        level_set = set()
+        bull = bullets_category([p.text for p in self.doc.paragraphs])
+        for p in self.doc.paragraphs:
+            if pn > to_page:
+                break
+            question_level, p_text = docx_question_level(p, bull)
+            if not p_text.strip("\n"):
+                continue
+            lines.append((question_level, p_text))
+            level_set.add(question_level)
+            for run in p.runs:
+                if 'lastRenderedPageBreak' in run._element.xml:
+                    pn += 1
+                    continue
+                if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
+                    pn += 1
+
+        sorted_levels = sorted(level_set)
+
+        h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
+        h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
+
+        root = Node(level=0, depth=h2_level, texts=[])
+        root.build_tree(lines)
+
+        return [("\n").join(element) for element in root.get_tree() if element]
 
     def __str__(self) -> str:
         return f'''
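Reviewer note: the rewritten `Docx.__call__` drops the old visit/sections merging pass. It now records every heading level it sees in `level_set`, derives a cut depth `h2_level`, and delegates merging to `Node.build_tree`. A minimal sketch of the level-selection step, using a hypothetical set of observed levels (here 7 stands for the body-text fallback from `docx_question_level`):

```python
# Sketch of the h2_level pick above, for a hypothetical document whose
# paragraphs were classified into heading levels {1, 2, 7}.
sorted_levels = sorted({1, 2, 7})

# Prefer the second-smallest heading level as the cut depth.
h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
# Back off one step if that pick is the deepest level (pure body text).
h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level

print(h2_level)  # 2: chunks open at second-level headings
```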
@@ -163,7 +154,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         chunks = Docx()(filename, binary)
         callback(0.7, "Finish parsing.")
         return tokenize_chunks(chunks, doc, eng, None)
 
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
         if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
@@ -172,7 +163,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
                               from_page=from_page, to_page=to_page, callback=callback)[0]:
             sections.append(txt + poss)
 
-    elif re.search(r"\.txt$", filename, re.IGNORECASE):
+    elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = get_text(filename, binary)
         sections = txt.split("\n")
@@ -203,13 +194,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
 
     make_colon_as_title(sections)
     bull = bullets_category(sections)
-    chunks = hierarchical_merge(bull, sections, 5)
-    if not chunks:
+    res = tree_merge(bull, sections, 2)
+
+    if not res:
         callback(0.99, "No chunk parsed out.")
 
-    return tokenize_chunks(["\n".join(ck)
-                            for ck in chunks], doc, eng, pdf_parser)
+    return tokenize_chunks(res, doc, eng, pdf_parser)
+
+    # chunks = hierarchical_merge(bull, sections, 5)
+    # return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
 
 
 if __name__ == "__main__":
     import sys
@@ -138,6 +138,8 @@ def label_question(question, kbs):
         else:
             all_tags = json.loads(all_tags)
         tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
+        if not tag_kbs:
+            return tags
         tags = settings.retrievaler.tag_query(question,
                                               list(set([kb.tenant_id for kb in tag_kbs])),
                                               tag_kb_ids,
@@ -56,6 +56,6 @@ class ProcessBase(ComponentBase):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
     async def _invoke(self, **kwargs):
         raise NotImplementedError()
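Reviewer note: `os.environ.get` returns a string whenever the variable is set, so without the cast the decorator receives `"600"` instead of `600`. A quick repro of the difference:

```python
import os

# When a user sets COMPONENT_EXEC_TIMEOUT, it arrives as a string:
os.environ["COMPONENT_EXEC_TIMEOUT"] = "300"
raw = os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)
print(type(raw), raw)              # <class 'str'> 300

# The fix normalizes both the int default and the str env value to int:
timeout_s = int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))
print(type(timeout_s), timeout_s)  # <class 'int'> 300
```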
@@ -241,7 +241,6 @@ class Chunker(ProcessBase):
             "laws": self._laws,
             "presentation": self._presentation,
             "one": self._one,
-            "toc": self._toc,
         }
 
         try:
@@ -68,6 +68,7 @@ FACTORY_DEFAULT_BASE_URL = {
     SupportedLiteLLMProvider.Lingyi_AI: "https://api.lingyiwanwu.com/v1",
     SupportedLiteLLMProvider.GiteeAI: "https://ai.gitee.com/v1/",
     SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
+    SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
 }
 
 
@@ -36,7 +36,7 @@ from zhipuai import ZhipuAI
 
 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
-from rag.utils import num_tokens_from_string
+from rag.utils import num_tokens_from_string, total_token_count_from_response
 
 
 # Error message constants
@@ -143,9 +143,10 @@ class Base(ABC):
         logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
         if self.model_name.lower().find("qwen3") >= 0:
             kwargs["extra_body"] = {"enable_thinking": False}
+
         response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
-        if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
+
+        if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
             return "", 0
         ans = response.choices[0].message.content.strip()
         if response.choices[0].finish_reason == "length":
@@ -155,10 +156,12 @@
     def _chat_streamly(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
         reasoning_start = False
-        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
+        if kwargs.get("stop") or "stop" in gen_conf:
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
+        else:
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
 
         for resp in response:
             if not resp.choices:
                 continue
@@ -190,21 +193,30 @@ class Base(ABC):
             return ans + LENGTH_NOTIFICATION_CN
         return ans + LENGTH_NOTIFICATION_EN
 
-    def _exceptions(self, e, attempt):
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+        }
+
+    def _should_retry(self, error_code: str) -> bool:
+        return error_code in self._retryable_errors
+
+    def _exceptions(self, e, attempt) -> str | None:
         logging.exception("OpenAI chat_with_tools")
         # Classify the error
         error_code = self._classify_error(e)
         if attempt == self.max_retries:
             error_code = LLMErrorCode.ERROR_MAX_RETRIES
 
-        # Check if it's a rate limit error or server error and not the last attempt
-        should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
-        if not should_retry:
-            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+            return None
 
-        delay = self._get_delay()
-        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-        time.sleep(delay)
+        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
     def _verbose_tool_use(self, name, args, res):
         return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
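Reviewer note: pulling the retry decision into a `_retryable_errors` property lets subclasses widen the set without copying the whole `_exceptions` body (the `AzureChat` hunk below does exactly that, adding `ERROR_QUOTA`). A minimal sketch of the override pattern, with plain strings standing in for the `LLMErrorCode` constants:

```python
# Minimal sketch of the override pattern; "RATE_LIMIT"/"SERVER"/"QUOTA"
# are hypothetical stand-ins for the LLMErrorCode constants.
class Base:
    @property
    def _retryable_errors(self) -> set[str]:
        return {"RATE_LIMIT", "SERVER"}

    def _should_retry(self, error_code: str) -> bool:
        return error_code in self._retryable_errors


class AzureChat(Base):
    @property
    def _retryable_errors(self) -> set[str]:
        # Azure quota errors are treated as transient, so retry them too.
        return {"RATE_LIMIT", "SERVER", "QUOTA"}


print(Base()._should_retry("QUOTA"))       # False
print(AzureChat()._should_retry("QUOTA"))  # True
```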
@@ -445,15 +457,7 @@ class Base(ABC):
         yield total_tokens
 
     def total_token_count(self, resp):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-        return 0
+        return total_token_count_from_response(resp)
 
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""
@@ -541,6 +545,14 @@ class AzureChat(Base):
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
         self.model_name = model_name
 
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+            LLMErrorCode.ERROR_QUOTA,
+        }
+
 
 class BaiChuanChat(Base):
     _FACTORY_NAME = "BaiChuan"
@@ -629,6 +641,10 @@ class ZhipuChat(Base):
     def _clean_conf(self, gen_conf):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
+        gen_conf = self._clean_conf_plealty(gen_conf)
+        return gen_conf
+
+    def _clean_conf_plealty(self, gen_conf):
         if "presence_penalty" in gen_conf:
             del gen_conf["presence_penalty"]
         if "frequency_penalty" in gen_conf:
@@ -636,22 +652,14 @@ class ZhipuChat(Base):
         return gen_conf
 
     def chat_with_tools(self, system: str, history: list, gen_conf: dict):
-        if "presence_penalty" in gen_conf:
-            del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            del gen_conf["frequency_penalty"]
+        gen_conf = self._clean_conf_plealty(gen_conf)
 
         return super().chat_with_tools(system, history, gen_conf)
 
     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":
             history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        if "presence_penalty" in gen_conf:
-            del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            del gen_conf["frequency_penalty"]
+        gen_conf = self._clean_conf(gen_conf)
         ans = ""
         tk_count = 0
         try:
@@ -677,11 +685,7 @@ class ZhipuChat(Base):
         yield tk_count
 
     def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
-        if "presence_penalty" in gen_conf:
-            del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            del gen_conf["frequency_penalty"]
-
+        gen_conf = self._clean_conf_plealty(gen_conf)
         return super().chat_streamly_with_tools(system, history, gen_conf)
 
 
@@ -858,6 +862,7 @@ class MistralChat(Base):
         return gen_conf
 
     def _chat(self, history, gen_conf={}, **kwargs):
+        gen_conf = self._clean_conf(gen_conf)
         response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
         ans = response.choices[0].message.content
         if response.choices[0].finish_reason == "length":
@@ -870,9 +875,7 @@ class MistralChat(Base):
     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":
             history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
+        gen_conf = self._clean_conf(gen_conf)
         ans = ""
         total_tokens = 0
         try:
@@ -1302,10 +1305,6 @@ class LiteLLMBase(ABC):
         "302.AI",
     ]
 
-    import litellm
-
-    litellm._turn_on_debug()
-
     def __init__(self, key, model_name, base_url=None, **kwargs):
         self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
         self.provider = kwargs.get("provider", "")
@@ -1429,21 +1428,30 @@
             return ans + LENGTH_NOTIFICATION_CN
         return ans + LENGTH_NOTIFICATION_EN
 
-    def _exceptions(self, e, attempt):
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+        }
+
+    def _should_retry(self, error_code: str) -> bool:
+        return error_code in self._retryable_errors
+
+    def _exceptions(self, e, attempt) -> str | None:
         logging.exception("OpenAI chat_with_tools")
         # Classify the error
         error_code = self._classify_error(e)
         if attempt == self.max_retries:
             error_code = LLMErrorCode.ERROR_MAX_RETRIES
 
-        # Check if it's a rate limit error or server error and not the last attempt
-        should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
-        if not should_retry:
-            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+            return None
 
-        delay = self._get_delay()
-        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-        time.sleep(delay)
+        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
     def _verbose_tool_use(self, name, args, res):
         return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
@@ -25,7 +25,7 @@ from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 
 from rag.nlp import is_english
-from rag.prompts import vision_llm_describe_prompt
+from rag.prompts.generator import vision_llm_describe_prompt
 from rag.utils import num_tokens_from_string
 
@@ -33,7 +33,7 @@ from zhipuai import ZhipuAI
 from api import settings
 from api.utils.file_utils import get_home_cache_dir
 from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate
+from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
 
 
 class Base(ABC):
@@ -52,15 +52,7 @@ class Base(ABC):
         raise NotImplementedError("Please implement encode method!")
 
     def total_token_count(self, resp):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-        return 0
+        return total_token_count_from_response(resp)
 
 
 class DefaultEmbedding(Base):
@@ -497,7 +489,6 @@ class MistralEmbed(Base):
     def encode_queries(self, text):
         import time
         import random
-
         retry_max = 5
         while retry_max > 0:
             try:
@@ -662,7 +653,7 @@ class OpenAI_APIEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("url cannot be None")
-        #base_url = urljoin(base_url, "v1")
+        base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name.split("___")[0]
 
@@ -945,6 +936,7 @@ class GiteeEmbed(SILICONFLOWEmbed):
         base_url = "https://ai.gitee.com/v1/embeddings"
         super().__init__(key, model_name, base_url)
 
+
 class DeepInfraEmbed(OpenAIEmbed):
     _FACTORY_NAME = "DeepInfra"
 
@@ -963,7 +955,7 @@ class Ai302Embed(Base):
         super().__init__(key, model_name, base_url)
 
 
-class CometEmbed(OpenAIEmbed):
+class CometAPIEmbed(OpenAIEmbed):
     _FACTORY_NAME = "CometAPI"
 
     def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
@@ -30,7 +30,7 @@ from yarl import URL
 
 from api import settings
 from api.utils.file_utils import get_home_cache_dir
 from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate
+from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
 
 class Base(ABC):
     def __init__(self, key, model_name, **kwargs):
@@ -44,18 +44,7 @@ class Base(ABC):
         raise NotImplementedError("Please implement encode method!")
 
     def total_token_count(self, resp):
-        if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
-            try:
-                return resp.usage.total_tokens
-            except Exception:
-                pass
-
-        if 'usage' in resp and 'total_tokens' in resp['usage']:
-            try:
-                return resp["usage"]["total_tokens"]
-            except Exception:
-                pass
-        return 0
+        return total_token_count_from_response(resp)
 
 
 class DefaultRerank(Base):
@@ -365,7 +354,7 @@ class OpenAI_APIRerank(Base):
         max_rank = np.max(rank)
 
         # Avoid division by zero if all ranks are identical
-        if np.isclose(min_rank, max_rank, atol=1e-3):
+        if not np.isclose(min_rank, max_rank, atol=1e-3):
             rank = (rank - min_rank) / (max_rank - min_rank)
         else:
             rank = np.zeros_like(rank)
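Reviewer note: the previous condition normalized exactly when all scores were (near-)identical, which is the one case that divides by roughly zero; the inverted check scales only when there is spread. A small before/after check of the fixed logic:

```python
import numpy as np

def normalize(rank: np.ndarray) -> np.ndarray:
    # Fixed logic from the hunk above: scale only when scores differ.
    min_rank, max_rank = np.min(rank), np.max(rank)
    if not np.isclose(min_rank, max_rank, atol=1e-3):
        return (rank - min_rank) / (max_rank - min_rank)
    return np.zeros_like(rank)  # identical scores carry no ordering signal

print(normalize(np.array([0.2, 0.5, 0.8])))  # [0.  0.5 1. ]
print(normalize(np.array([0.7, 0.7, 0.7])))  # [0. 0. 0.] instead of 0/0
```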
@@ -236,7 +236,7 @@ class DeepInfraSeq2txt(Base):
         self.model_name = model_name
 
 
-class CometSeq2txt(Base):
+class CometAPISeq2txt(Base):
     _FACTORY_NAME = "CometAPI"
 
     def __init__(self, key, model_name="whisper-1", base_url="https://api.cometapi.com/v1", **kwargs):
@@ -189,6 +189,13 @@ BULLET_PATTERN = [[
     r"Chapter (I+V?|VI*|XI|IX|X)",
     r"Section [0-9]+",
     r"Article [0-9]+"
+], [
+    r"^#[^#]",
+    r"^##[^#]",
+    r"^###.*",
+    r"^####.*",
+    r"^#####.*",
+    r"^######.*",
 ]
 ]
 
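Reviewer note: the new pattern group maps markdown heading markers to outline levels (`#` to 1, `##` to 2, and so on), which is what lets `.md`/`.markdown`/`.mdx` files flow through the same `bullets_category`/`tree_merge` path as DOCX and PDF. A quick check of how the regexes rank lines:

```python
import re

# The markdown group added above; its index inside BULLET_PATTERN
# depends on where it sits in the list at runtime.
markdown_group = [r"^#[^#]", r"^##[^#]", r"^###.*", r"^####.*", r"^#####.*", r"^######.*"]

def level(text: str) -> int:
    for i, pattern in enumerate(markdown_group):
        if re.match(pattern, text):
            return i + 1                 # first matching pattern wins
    return len(markdown_group) + 1       # body text falls past every heading

print(level("# Title"))          # 1
print(level("## Section"))       # 2
print(level("plain paragraph"))  # 7
```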
@@ -429,8 +436,58 @@ def not_title(txt):
         return True
     return re.search(r"[,;,。;!!]", txt)
 
+
+def tree_merge(bull, sections, depth):
+
+    if not sections or bull < 0:
+        return sections
+    if isinstance(sections[0], type("")):
+        sections = [(s, "") for s in sections]
+
+    # filter out position information in pdf sections
+    sections = [(t, o) for t, o in sections if
+                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
+
+    def get_level(bull, section):
+        text, layout = section
+        text = re.sub(r"\u3000", " ", text).strip()
+
+        for i, title in enumerate(BULLET_PATTERN[bull]):
+            if re.match(title, text.strip()):
+                return i+1, text
+        else:
+            if re.search(r"(title|head)", layout) and not not_title(text):
+                return len(BULLET_PATTERN[bull])+1, text
+            else:
+                return len(BULLET_PATTERN[bull])+2, text
+
+    level_set = set()
+    lines = []
+    for section in sections:
+        level, text = get_level(bull, section)
+
+        if not text.strip("\n"):
+            continue
+
+        lines.append((level, text))
+        level_set.add(level)
+
+    sorted_levels = sorted(list(level_set))
+
+    if depth <= len(sorted_levels):
+        target_level = sorted_levels[depth - 1]
+    else:
+        target_level = sorted_levels[-1]
+
+    if target_level == len(BULLET_PATTERN[bull]) + 2:
+        target_level = sorted_levels[-2] if len(sorted_levels) > 1 else sorted_levels[0]
+
+    root = Node(level=0, depth=target_level, texts=[])
+    root.build_tree(lines)
+
+    return [("\n").join(element) for element in root.get_tree() if element]
+
+
 def hierarchical_merge(bull, sections, depth):
 
     if not sections or bull < 0:
         return []
     if isinstance(sections[0], type("")):
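Reviewer note: a worked sketch of the new `tree_merge` contract on a markdown-style outline. The expected output assumes `bullets_category` picks the markdown pattern group added above; the actual `bull` index and level numbers depend on `BULLET_PATTERN` at runtime.

```python
from rag.nlp import bullets_category, tree_merge

sections = [
    "# Employment Law",
    "## Article 1",
    "Both parties shall perform their obligations in full.",
    "## Article 2",
    "Termination requires thirty days of written notice.",
]

bull = bullets_category(sections)       # expected: the markdown group
chunks = tree_merge(bull, sections, 2)  # depth=2, as laws.chunk now calls it

# Expected shape: one chunk per "## Article N", each carrying its
# ancestor title, e.g.
#   "# Employment Law\n## Article 1\nBoth parties shall ..."
for chunk in chunks:
    print(chunk, end="\n---\n")
```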
@@ -632,7 +689,7 @@ def docx_question_level(p, bull=-1):
     for j, title in enumerate(BULLET_PATTERN[bull]):
         if re.match(title, txt):
             return j + 1, txt
-    return len(BULLET_PATTERN[bull]), txt
+    return len(BULLET_PATTERN[bull])+1, txt
 
 
 def concat_img(img1, img2):
@@ -735,3 +792,68 @@ def get_delimiters(delimiters: str):
     dels_pattern = "|".join(dels)
 
     return dels_pattern
+
+
+class Node:
+    def __init__(self, level, depth=-1, texts=None):
+        self.level = level
+        self.depth = depth
+        self.texts = texts if texts is not None else []  # content held by this node
+        self.children = []  # child nodes
+
+    def add_child(self, child_node):
+        self.children.append(child_node)
+
+    def get_children(self):
+        return self.children
+
+    def get_level(self):
+        return self.level
+
+    def get_texts(self):
+        return self.texts
+
+    def set_texts(self, texts):
+        self.texts = texts
+
+    def add_text(self, text):
+        self.texts.append(text)
+
+    def clear_text(self):
+        self.texts = []
+
+    def __repr__(self):
+        return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"
+
+    def build_tree(self, lines):
+        stack = [self]
+        for line in lines:
+            level, text = line
+            node = Node(level=level, texts=[text])
+
+            if level <= self.depth or self.depth == -1:
+                while stack and level <= stack[-1].get_level():
+                    stack.pop()
+
+                stack[-1].add_child(node)
+                stack.append(node)
+            else:
+                stack[-1].add_text(text)
+        return self
+
+    def get_tree(self):
+        tree_list = []
+        self._dfs(self, tree_list, 0, [])
+        return tree_list
+
+    def _dfs(self, node, tree_list, current_depth, titles):
+
+        if node.get_texts():
+            if 0 < node.get_level() < self.depth:
+                titles.extend(node.get_texts())
+            else:
+                combined_text = ["\n".join(titles + node.get_texts())]
+                tree_list.append(combined_text)
+
+        for child in node.get_children():
+            self._dfs(child, tree_list, current_depth + 1, titles.copy())
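Reviewer note: `build_tree` is a single stack pass over `(level, text)` pairs: a line at or above the cut depth opens a new node, anything deeper is appended to the nearest open node, and `_dfs` then prepends accumulated ancestor titles to each leaf chunk. A minimal end-to-end sketch, assuming `Node` is importable as in the laws.py hunk above:

```python
from rag.nlp import Node

lines = [
    (1, "Chapter I"),
    (2, "Section 1"),
    (9, "Body text of section one."),
    (2, "Section 2"),
    (9, "Body text of section two."),
]

root = Node(level=0, depth=2, texts=[])
root.build_tree(lines)

# get_tree() yields one single-item list per cut node; levels strictly
# below the depth (here, level 1) act as inherited titles.
for element in root.get_tree():
    print("\n".join(element), end="\n---\n")
# "Chapter I\nSection 1\nBody text of section one." then
# "Chapter I\nSection 2\nBody text of section two."
```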
@@ -56,7 +56,7 @@ class FulltextQueryer:
     def rmWWW(txt):
         patts = [
             (
-                r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
+                r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
                 "",
             ),
             (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
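Reviewer note: the only change is adding 怎么办 ("what to do") to the head of the Chinese stop-phrase alternation, so it is stripped from queries before full-text matching, like the other interrogatives. A quick check with the alternation trimmed to a few entries for readability:

```python
import re

# Trimmed version of the updated alternation; 怎么办 is the new entry.
patt = r"是*(怎么办|什么样的|哪家|请问)是*"

print(re.sub(patt, "", "合同到期了怎么办"))  # -> 合同到期了
```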
@@ -1,6 +1,6 @@
-from . import prompts
+from . import generator
 
-__all__ = [name for name in dir(prompts)
+__all__ = [name for name in dir(generator)
            if not name.startswith('_')]
 
-globals().update({name: getattr(prompts, name) for name in __all__})
+globals().update({name: getattr(generator, name) for name in __all__})
@@ -22,7 +22,7 @@ from typing import Tuple
 import jinja2
 import json_repair
 
 from api.utils import hash_str2int
-from rag.prompts.prompt_template import load_prompt
+from rag.prompts.template import load_prompt
 from rag.settings import TAG_FLD
 from rag.utils import encoder, num_tokens_from_string
@@ -15,7 +15,7 @@
 #
 import os
 import logging
-from api.utils import get_base_config, decrypt_database_config
+from api.utils.configs import get_base_config, decrypt_database_config
 from api.utils.file_utils import get_project_base_directory
 
 # Server
@@ -88,6 +88,20 @@ def num_tokens_from_string(string: str) -> int:
     except Exception:
         return 0
 
+
+def total_token_count_from_response(resp):
+    if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
+        try:
+            return resp.usage.total_tokens
+        except Exception:
+            pass
+
+    if 'usage' in resp and 'total_tokens' in resp['usage']:
+        try:
+            return resp["usage"]["total_tokens"]
+        except Exception:
+            pass
+    return 0
+
 
 def truncate(string: str, max_len: int) -> str:
     """Returns truncated text if the length of text exceed max_len."""
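Reviewer note: this helper consolidates the three near-identical `total_token_count` bodies previously duplicated in chat_model, embedding_model, and rerank_model. It accepts both attribute-style SDK responses and plain dicts; a quick check:

```python
from types import SimpleNamespace

from rag.utils import total_token_count_from_response

# Works with attribute-style SDK responses...
resp_obj = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))
print(total_token_count_from_response(resp_obj))   # 42

# ...with plain-dict responses...
resp_dict = {"usage": {"total_tokens": 17}}
print(total_token_count_from_response(resp_dict))  # 17

# ...and falls back to 0 when no usage info is present.
print(total_token_count_from_response({}))         # 0
```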
@@ -108,6 +108,19 @@ class RAGFlowMinio:
             logging.exception(f"obj_exist {bucket}/{filename} got exception")
         return False
 
+    def bucket_exists(self, bucket):
+        try:
+            if not self.conn.bucket_exists(bucket):
+                return False
+            else:
+                return True
+        except S3Error as e:
+            if e.code in ["NoSuchKey", "NoSuchBucket", "ResourceNotFound"]:
+                return False
+        except Exception:
+            logging.exception(f"bucket_exist {bucket} got exception")
+        return False
+
     def get_presigned_url(self, bucket, fnm, expires):
         for _ in range(10):
             try:
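Reviewer note: the new `bucket_exists` follows the same error-handling shape as `obj_exist` above: known not-found S3 codes map to False quietly, anything unexpected is logged and still returns False. A usage sketch, assuming a `RAGFlowMinio` instance whose connection and credentials come from the service settings as elsewhere in rag/utils:

```python
# Usage sketch; RAGFlowMinio wiring/credentials are assumed to come
# from the service configuration.
storage = RAGFlowMinio()

if not storage.bucket_exists("ragflow-demo"):
    # Callers can now probe for a bucket up front instead of catching
    # S3Error("NoSuchBucket") at the first put/get.
    print("bucket 'ragflow-demo' is missing")
```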
@@ -3,7 +3,7 @@ import logging
 import pymysql
 from urllib.parse import quote_plus
 
-from api.utils import get_base_config
+from api.utils.configs import get_base_config
 from rag.utils import singleton