Fix errors detected by Ruff (#3918)

### What problem does this PR solve?

Fix errors detected by Ruff

### Type of change

- [x] Refactoring
Author: Zhichang Yu, 2024-12-08 14:21:12 +08:00 (committed by GitHub)
Commit: 0d68a6cd1b (parent: e267a026f3)
97 changed files with 2558 additions and 1976 deletions
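
The PR text does not name the rules, but the patterns in the diff below line up with a handful of well-known checks: E741 (ambiguous variable names such as `l`), E701 (compound statement on one line), F401 (unused import), F403/F405 (star imports), F541 (f-string without placeholders), F841 (unused local variable), and E402 (module-level import not at top of file). Most of these are auto-fixable with `ruff check --fix`. A minimal sketch of the two patterns that dominate the diff, assuming that rule mapping (it is inferred from the changes, not stated by the PR):

```python
# Before/after sketch of the two commonest fixes in this PR.
# Rule codes are inferred from the patterns, not stated by the PR.

def keep_nonempty_before(sections):
    # E741: "l" is easily confused with "1" and "I" in many fonts.
    return [(l, "") for l in sections if l]  # noqa: E741

def keep_nonempty_after(sections):
    # The identical comprehension with a descriptive loop variable.
    return [(line, "") for line in sections if line]

# E701 is the one-line compound statement; the fix is just a line break:
#     if not rows: continue    ->    if not rows:
#                                        continue
assert keep_nonempty_before(["a", ""]) == keep_nonempty_after(["a", ""]) == [("a", "")]
```

Renaming the loop variable changes no behavior; both rules are about readability, which is why the whole PR is tagged as refactoring.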

---

@@ -94,7 +94,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
@@ -102,7 +102,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
@@ -112,7 +112,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")

---

@@ -75,7 +75,7 @@ def chunk(
_add_content(msg, msg.get_content_type())
sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(l, "") for l in HtmlParser.parser_txt("\n".join(html_txt)) if l
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line
]
st = timer()

---

@@ -18,7 +18,8 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
)
for c in chunks: c["docnm_kwd"] = filename
for c in chunks:
c["docnm_kwd"] = filename
doc = {
"docnm_kwd": filename,

---

@@ -48,7 +48,7 @@ class Docx(DocxParser):
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
return [l for l in lines if l]
return [line for line in lines if line]
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document(
@@ -60,7 +60,8 @@ class Docx(DocxParser):
if pn > to_page:
break
question_level, p_text = docx_question_level(p, bull)
if not p_text.strip("\n"):continue
if not p_text.strip("\n"):
continue
lines.append((question_level, p_text))
for run in p.runs:
@@ -78,19 +79,21 @@ class Docx(DocxParser):
if lines[e][0] <= lines[s][0]:
break
e += 1
if e - s == 1 and visit[s]: continue
if e - s == 1 and visit[s]:
continue
sec = []
next_level = lines[s][0] + 1
while not sec and next_level < 22:
for i in range(s+1, e):
if lines[i][0] != next_level: continue
if lines[i][0] != next_level:
continue
sec.append(lines[i][1])
visit[i] = True
next_level += 1
sec.insert(0, lines[s][1])
sections.append("\n".join(sec))
return [l for l in sections if l]
return [s for s in sections if s]
def __str__(self) -> str:
return f'''
@@ -168,13 +171,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
elif re.search(r"\.doc$", filename, re.IGNORECASE):
@@ -182,7 +185,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:

---

@@ -190,7 +190,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0]) < 3:
sections = [(t, l, [[0] * 5]) for t, l in sections]
sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
# set pivot using the most frequent type of title,
# then merge between 2 pivot
if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1:
@@ -211,7 +211,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
else:
bull = bullets_category([txt for txt, _, _ in sections])
most_level, levels = title_frequency(
bull, [(txt, l) for txt, l, poss in sections])
bull, [(txt, lvl) for txt, lvl, _ in sections])
assert len(sections) == len(levels)
sec_ids = []
@@ -225,7 +225,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections = [(txt, sec_ids[i], poss)
for i, (txt, _, poss) in enumerate(sections)]
for (img, rows), poss in tbls:
if not rows: continue
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0], -1,
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))

---

@@ -54,7 +54,8 @@ class Pdf(PdfParser):
sections = [(b["text"], self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)]
for (img, rows), poss in tbls:
if not rows:continue
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
@@ -109,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:

---

@@ -171,7 +171,7 @@ class Pdf(PdfParser):
tbl_bottom = tbls[tbl_index][1][0][4]
tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
.format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
tbl_text = ''.join(tbls[tbl_index][0][1])
_tbl_text = ''.join(tbls[tbl_index][0][1])
return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag,
@@ -325,9 +325,11 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
txt = get_text(filename, binary)
lines = txt.split("\n")
comma, tab = 0, 0
for l in lines:
if len(l.split(",")) == 2: comma += 1
if len(l.split("\t")) == 2: tab += 1
for line in lines:
if len(line.split(",")) == 2:
comma += 1
if len(line.split("\t")) == 2:
tab += 1
delimiter = "\t" if tab >= comma else ","
fails = []
@@ -336,18 +338,21 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
while i < len(lines):
arr = lines[i].split(delimiter)
if len(arr) != 2:
if question: answer += "\n" + lines[i]
if question:
answer += "\n" + lines[i]
else:
fails.append(str(i+1))
elif len(arr) == 2:
if question and answer: res.append(beAdoc(deepcopy(doc), question, answer, eng))
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
question, answer = arr
i += 1
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
if question: res.append(beAdoc(deepcopy(doc), question, answer, eng))
if question:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@@ -367,19 +372,18 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
last_question, last_answer = "", ""
_last_question, last_answer = "", ""
question_stack, level_stack = [], []
code_block = False
level_index = [-1] * 7
for index, l in enumerate(lines):
if l.strip().startswith('```'):
for index, line in enumerate(lines):
if line.strip().startswith('```'):
code_block = not code_block
question_level, question = 0, ''
if not code_block:
question_level, question = mdQuestionLevel(l)
question_level, question = mdQuestionLevel(line)
if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{l}'
last_answer = f'{last_answer}\n{line}'
else: # is a question
if last_answer.strip():
sum_question = '\n'.join(question_stack)
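
Two renames in this file (`tbl_text` to `_tbl_text`, `last_question` to `_last_question`) take a different route: instead of deleting an assignment that is never read, the leading underscore marks the binding as intentionally unused, which matches Ruff's default dummy-variable pattern and silences the warning (F841, by my reading of the diff). A small sketch of the pattern:

```python
def table_tag(pn: int, cells: list[str]) -> str:
    # Before: tbl_text = "".join(cells) was assigned but never read -> F841.
    # The underscore prefix keeps the line (handy when stepping through in a
    # debugger) while telling the linter the value is deliberately discarded.
    _tbl_text = "".join(cells)
    return "@@{}##".format(pn)

assert table_tag(3, ["a", "b"]) == "@@3##"
```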

---

@@ -41,14 +41,16 @@ class Excel(ExcelParser):
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
if not rows:continue
if not rows:
continue
headers = [cell.value for cell in rows[0]]
missed = set([i for i, h in enumerate(headers) if h is None])
headers = [
cell.value for i,
cell in enumerate(
rows[0]) if i not in missed]
if not headers:continue
if not headers:
continue
data = []
for i, r in enumerate(rows[1:]):
rn += 1
@@ -88,7 +90,6 @@ def trans_bool(s):
def column_data_type(arr):
arr = list(arr)
uni = len(set([a for a in arr if a is not None]))
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
trans = {t: f for f, t in
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
@@ -157,7 +158,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
continue
if i >= to_page:
break
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
row = [field for field in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers):
fails.append(str(i))
continue

---

@@ -13,12 +13,124 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .embedding_model import *
from .chat_model import *
from .cv_model import *
from .rerank_model import *
from .sequence2txt_model import *
from .tts_model import *
from .embedding_model import (
OllamaEmbed,
LocalAIEmbed,
OpenAIEmbed,
AzureEmbed,
XinferenceEmbed,
QWenEmbed,
ZhipuEmbed,
FastEmbed,
YoudaoEmbed,
BaiChuanEmbed,
JinaEmbed,
DefaultEmbedding,
MistralEmbed,
BedrockEmbed,
GeminiEmbed,
NvidiaEmbed,
LmStudioEmbed,
OpenAI_APIEmbed,
CoHereEmbed,
TogetherAIEmbed,
PerfXCloudEmbed,
UpstageEmbed,
SILICONFLOWEmbed,
ReplicateEmbed,
BaiduYiyanEmbed,
VoyageEmbed,
HuggingFaceEmbed,
VolcEngineEmbed,
)
from .chat_model import (
GptTurbo,
AzureChat,
ZhipuChat,
QWenChat,
OllamaChat,
LocalAIChat,
XinferenceChat,
MoonshotChat,
DeepSeekChat,
VolcEngineChat,
BaiChuanChat,
MiniMaxChat,
MistralChat,
GeminiChat,
BedrockChat,
GroqChat,
OpenRouterChat,
StepFunChat,
NvidiaChat,
LmStudioChat,
OpenAI_APIChat,
CoHereChat,
LeptonAIChat,
TogetherAIChat,
PerfXCloudChat,
UpstageChat,
NovitaAIChat,
SILICONFLOWChat,
YiChat,
ReplicateChat,
HunyuanChat,
SparkChat,
BaiduYiyanChat,
AnthropicChat,
GoogleChat,
HuggingFaceChat,
)
from .cv_model import (
GptV4,
AzureGptV4,
OllamaCV,
XinferenceCV,
QWenCV,
Zhipu4V,
LocalCV,
GeminiCV,
OpenRouterCV,
LocalAICV,
NvidiaCV,
LmStudioCV,
StepFunCV,
OpenAI_APICV,
TogetherAICV,
YiCV,
HunyuanCV,
)
from .rerank_model import (
LocalAIRerank,
DefaultRerank,
JinaRerank,
YoudaoRerank,
XInferenceRerank,
NvidiaRerank,
LmStudioRerank,
OpenAI_APIRerank,
CoHereRerank,
TogetherAIRerank,
SILICONFLOWRerank,
BaiduYiyanRerank,
VoyageRerank,
QWenRerank,
)
from .sequence2txt_model import (
GPTSeq2txt,
QWenSeq2txt,
AzureSeq2txt,
XinferenceSeq2txt,
TencentCloudSeq2txt,
)
from .tts_model import (
FishAudioTTS,
QwenTTS,
OpenAITTS,
SparkTTS,
XinferenceTTS,
)
EmbeddingModel = {
"Ollama": OllamaEmbed,
@@ -48,7 +160,7 @@ EmbeddingModel = {
"BaiduYiyan": BaiduYiyanEmbed,
"Voyage AI": VoyageEmbed,
"HuggingFace": HuggingFaceEmbed,
"VolcEngine":VolcEngineEmbed,
"VolcEngine": VolcEngineEmbed,
}
CvModel = {
@@ -68,7 +180,7 @@ CvModel = {
"OpenAI-API-Compatible": OpenAI_APICV,
"TogetherAI": TogetherAICV,
"01.AI": YiCV,
"Tencent Hunyuan": HunyuanCV
"Tencent Hunyuan": HunyuanCV,
}
ChatModel = {
@@ -111,7 +223,7 @@ ChatModel = {
}
RerankModel = {
"LocalAI":LocalAIRerank,
"LocalAI": LocalAIRerank,
"BAAI": DefaultRerank,
"Jina": JinaRerank,
"Youdao": YoudaoRerank,
@@ -132,7 +244,7 @@ Seq2txtModel = {
"Tongyi-Qianwen": QWenSeq2txt,
"Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt,
"Tencent Cloud": TencentCloudSeq2txt
"Tencent Cloud": TencentCloudSeq2txt,
}
TTSModel = {
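
The big hunk above replaces six `from .x import *` statements with explicit name lists. With a star import, neither a reader nor a linter can tell where `OllamaEmbed` or `GptTurbo` was defined, so the import is reported (F403) and every use of such a name is potentially undefined (F405); an explicit list also lets the unused-import check (F401) do its job. The trailing commas and `"key": value` spacing added to the model registries look like routine formatter normalization. A self-contained sketch, using the stdlib `math` module as a stand-in for the project's model modules:

```python
# Star-import style (what this PR removes): binds every public name of the
# module, so "tau" below would have no visible origin and a typo would only
# surface at runtime.
from math import *  # noqa: F403

# Explicit style (what this PR introduces): origins are traceable and tools
# can flag names that are imported but never used.
from math import pi, tau

RADIANS_PER = {
    "half_turn": pi,
    "full_turn": tau,  # trailing comma keeps future one-line diffs one line
}
assert abs(RADIANS_PER["full_turn"] - 2 * RADIANS_PER["half_turn"]) < 1e-12
```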

---

@@ -69,7 +69,8 @@ class Base(ABC):
stream=True,
**gen_conf)
for resp in response:
if not resp.choices: continue
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
@@ -81,7 +82,8 @@
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else: total_tokens = resp.usage.total_tokens
else:
total_tokens = resp.usage.total_tokens
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@@ -98,13 +100,15 @@
class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url)
class MoonshotChat(Base):
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url = "https://api.moonshot.cn/v1"
if not base_url:
base_url = "https://api.moonshot.cn/v1"
super().__init__(key, model_name, base_url)
@@ -128,7 +132,8 @@ class HuggingFaceChat(Base):
class DeepSeekChat(Base):
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
if not base_url: base_url = "https://api.deepseek.com/v1"
if not base_url:
base_url = "https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)
@@ -202,7 +207,8 @@ class BaiChuanChat(Base):
stream=True,
**self._format_params(gen_conf))
for resp in response:
if not resp.choices: continue
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
@@ -313,8 +319,10 @@ class ZhipuChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
try:
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
@@ -333,8 +341,10 @@ class ZhipuChat(Base):
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
tk_count = 0
try:
@@ -345,7 +355,8 @@
**gen_conf
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
@@ -354,7 +365,8 @@
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -372,11 +384,16 @@ class OllamaChat(Base):
history.insert(0, {"role": "system", "content": system})
try:
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
@@ -392,11 +409,16 @@ class OllamaChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
@@ -636,7 +658,8 @@ class MistralChat(Base):
messages=history,
**gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content: continue
if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
@@ -1196,7 +1219,8 @@ class SparkChat(Base):
assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
if model_name in model2version:
model_version = model2version[model_name]
else: model_version = model_name
else:
model_version = model_name
super().__init__(key, model_version, base_url)
@@ -1281,8 +1305,10 @@ class AnthropicChat(Base):
self.system = system
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
try:
@@ -1312,8 +1338,10 @@ class AnthropicChat(Base):
self.system = system
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
total_tokens = 0

---

@@ -25,6 +25,7 @@ import base64
from io import BytesIO
import json
import requests
from transformers import GenerationConfig
from rag.nlp import is_english
from api.utils import get_uuid
@@ -77,14 +78,16 @@ class Base(ABC):
stream=True
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -99,7 +102,7 @@ class Base(ABC):
buffered = BytesIO()
try:
image.save(buffered, format="JPEG")
except Exception as e:
except Exception:
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -139,7 +142,8 @@
class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url:
base_url="https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@@ -149,7 +153,8 @@ class GptV4(Base):
prompt = self.prompt(b64)
for i in range(len(prompt)):
for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text"
if "text" in c:
c["type"] = "text"
res = self.client.chat.completions.create(
model=self.model_name,
@@ -171,7 +176,8 @@ class AzureGptV4(Base):
prompt = self.prompt(b64)
for i in range(len(prompt)):
for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text"
if "text" in c:
c["type"] = "text"
res = self.client.chat.completions.create(
model=self.model_name,
@@ -344,14 +350,16 @@ class Zhipu4V(Base):
stream=True
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -389,11 +397,16 @@ class OllamaCV(Base):
if his["role"] == "user":
his["images"] = [image]
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
@@ -414,11 +427,16 @@ class OllamaCV(Base):
if his["role"] == "user":
his["images"] = [image]
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
@@ -469,7 +487,7 @@ class XinferenceCV(Base):
class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import client, GenerativeModel, GenerationConfig
from google.generativeai import client, GenerativeModel
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
@@ -503,7 +521,7 @@ class GeminiCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
history[-1]["parts"].append("data:image/jpeg;base64," + image)
response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
@@ -519,7 +537,6 @@ class GeminiCV(Base):
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
ans = ""
tk_count = 0
try:
for his in history:
if his["role"] == "assistant":
@@ -529,14 +546,15 @@ class GeminiCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
history[-1]["parts"].append("data:image/jpeg;base64," + image)
response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)), stream=True)
for resp in response:
if not resp.text: continue
if not resp.text:
continue
ans += resp.text
yield ans
except Exception as e:
@@ -632,7 +650,8 @@ class NvidiaCV(Base):
class StepFunCV(GptV4):
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url: base_url="https://api.stepfun.com/v1"
if not base_url:
base_url="https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
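
Besides the usual one-liner splits, this file carries two smaller fixes: `except Exception as e:` becomes `except Exception:` where `e` was never used, and `f"data:image/jpeg;base64,"` loses its `f` prefix because it interpolates nothing (F541). One hunk deserves a closer look, though: `GenerationConfig` is dropped from the `google.generativeai` import while a `from transformers import GenerationConfig` appears at the top of the file, so the `GenerationConfig(...)` calls in `GeminiCV` now resolve to an unrelated class; that reads like a behavior change rather than a lint fix. A sketch of the two mechanical patterns:

```python
import base64

def to_data_url(payload: bytes) -> str:
    encoded = base64.b64encode(payload).decode("utf-8")
    # Before: f"data:image/jpeg;base64," + encoded -- the f-prefix had no
    # placeholder to fill (F541); a plain literal is equivalent.
    return "data:image/jpeg;base64," + encoded

def first_byte(path: str) -> int:
    try:
        with open(path, "rb") as fh:
            return fh.read(1)[0]
    except Exception:  # "as e" dropped: the name was never read
        return -1

assert to_data_url(b"ok").startswith("data:image/jpeg;base64,")
assert first_byte("/definitely/missing") == -1
```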

---

@@ -15,12 +15,9 @@
#
import requests
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
from abc import ABC
from ollama import Client
from openai import OpenAI
import os
import json
from rag.utils import num_tokens_from_string
import base64
@@ -49,7 +46,8 @@ class Base(ABC):
class GPTSeq2txt(Base):
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
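
The import blocks here and in the next file shrink by a few lines (12 to 9 and 7 to 6, per the hunk headers): imports that were never referenced are removed, the classic F401 fix. The pattern in isolation:

```python
# F401: imported but never used. Deleting the import cannot change the
# module's behavior unless the import was relied on for side effects, which
# is why this fix is usually safe to apply automatically.

# Before:
#     import hashlib   # never referenced below -> F401
#     import hmac
# After:
import hmac

print(hmac.new(b"key", b"msg", "sha256").hexdigest())
```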

---

@@ -16,7 +16,6 @@
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
@@ -175,7 +174,8 @@ class QwenTTS(Base):
class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url

---

@@ -222,7 +222,8 @@ def bullets_category(sections):
def is_english(texts):
eng = 0
if not texts: return False
if not texts:
return False
for t in texts:
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
eng += 1
@@ -250,7 +251,8 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res = []
# wrap up as es documents
for ck in chunks:
if len(ck.strip()) == 0:continue
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
if pdf_parser:
@@ -269,7 +271,8 @@ def tokenize_chunks_docx(chunks, doc, eng, images):
res = []
# wrap up as es documents
for ck, image in zip(chunks, images):
if len(ck.strip()) == 0:continue
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
d["image"] = image
@@ -288,8 +291,10 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d = copy.deepcopy(doc)
tokenize(d, rows, eng)
d["content_with_weight"] = rows
if img: d["image"] = img
if poss: add_positions(d, poss)
if img:
d["image"] = img
if poss:
add_positions(d, poss)
res.append(d)
continue
de = "; " if eng else " "
@@ -387,9 +392,9 @@ def title_frequency(bull, sections):
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
levels[i] = bullets_size
most_level = bullets_size+1
for l, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if l <= bullets_size:
most_level = l
for level, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if level <= bullets_size:
most_level = level
break
return most_level, levels
@@ -504,7 +509,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。"):
def add_chunk(t, pos):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if not pos: pos = ""
if not pos:
pos = ""
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num

---

@@ -121,7 +121,8 @@ class FulltextQueryer:
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32: keywords.extend(syns)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
@@ -147,7 +148,8 @@ class FulltextQueryer:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
if len(keywords) < 32: keywords.extend([s for s in tk_syns if s])
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]

---

@@ -104,7 +104,6 @@ class RagTokenizer:
return HanziConv.toSimplified(line)
def dfs_(self, chars, s, preTks, tkslist):
MAX_L = 10
res = s
# if s > MAX_L or s>= len(chars):
if s >= len(chars):
@@ -184,12 +183,6 @@ class RagTokenizer:
return sorted(res, key=lambda x: x[1], reverse=True)
def merge_(self, tks):
patts = [
(r"[ ]+", " "),
(r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
]
# for p,s in patts: tks = re.sub(p, s, tks)
# if split chars is part of token
res = []
tks = re.sub(r"[ ]+", " ", tks).split()
@@ -284,7 +277,8 @@ class RagTokenizer:
same = 0
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
if same > 0: res.append(" ".join(tks[j: j + same]))
if same > 0:
res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1

---

@@ -62,10 +62,10 @@ class Dealer:
res = {}
f = open(fnm, "r")
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
arr = l.replace("\n", "").split("\t")
arr = line.replace("\n", "").split("\t")
if len(arr) < 2:
res[arr[0]] = 0
else:

---

@@ -47,7 +47,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __call__(self, chunks, random_state, callback=None):
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
if len(chunks) <= 1: return
if len(chunks) <= 1:
return
chunks = [(s, a) for s, a in chunks if len(a) > 0]
def summarize(ck_idx, lock):
@@ -66,7 +67,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
logging.debug(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt])
with lock:
if not len(embds[0]): return
if not len(embds[0]):
return
chunks.append((cnt, embds[0]))
except Exception as e:
logging.exception("summarize got exception")

---

@@ -33,14 +33,16 @@ def collect():
def main():
locations = collect()
if not locations:return
if not locations:
return
logging.info(f"TASKS: {len(locations)}")
for kb_id, loc in locations:
try:
if REDIS_CONN.is_alive():
try:
key = "{}/{}".format(kb_id, loc)
if REDIS_CONN.exist(key):continue
if REDIS_CONN.exist(key):
continue
file_bin = STORAGE_IMPL.get(kb_id, loc)
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))

---

@@ -23,18 +23,12 @@ import os
from api.utils.log_utils import initRootLogger
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger(CONSUMER_NAME, LOG_LEVELS)
from datetime import datetime
import json
import os
import hashlib
import copy
import re
import sys
import time
import threading
from functools import partial
@@ -63,6 +57,11 @@ from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger(CONSUMER_NAME, LOG_LEVELS)
BATCH_SIZE = 64
FACTORY = {
@@ -201,7 +200,8 @@ def build_chunks(task, progress_callback):
"doc_id": task["doc_id"],
"kb_id": str(task["kb_id"])
}
if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
if task["pagerank"]:
doc["pagerank_fea"] = int(task["pagerank"])
el = 0
for ck in cks:
d = copy.deepcopy(doc)
@@ -342,7 +342,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
if row["pagerank"]:
doc["pagerank_fea"] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:
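
The first hunk of this file is a reordering rather than a rename: the `CONSUMER_NAME`/`initRootLogger` block that used to sit above the imports moves below the last `from rag.utils...` import. Executable statements before an import make Ruff flag every subsequent import with E402 (module-level import not at top of file). Note the trade-off visible in the diff: logging is now configured only after all the heavyweight imports have run. A sketch with illustrative names (`init_root_logger` is not the project's function):

```python
# Before (sketch): configuration interleaved with imports.
#     import sys
#     NAME = "task_executor_" + sys.argv[1]
#     init_root_logger(NAME)    # executable statement...
#     import json               # ...so E402 fires on this import
#
# After: all imports first, then module-level configuration.
import json
import sys

CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
print(json.dumps({"consumer": CONSUMER_NAME}))
```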

---

@@ -41,15 +41,15 @@ def findMaxDt(fnm):
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
l = l.strip("\n")
if l == 'nan':
line = line.strip("\n")
if line == 'nan':
continue
if l > m:
m = l
except Exception as e:
if line > m:
m = line
except Exception:
pass
return m
@@ -59,15 +59,15 @@ def findMaxTm(fnm):
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
l = l.strip("\n")
if l == 'nan':
line = line.strip("\n")
if line == 'nan':
continue
if int(l) > m:
m = int(l)
except Exception as e:
if int(line) > m:
m = int(line)
except Exception:
pass
return m

---

@@ -32,7 +32,7 @@ class RAGFlowAzureSasBlob(object):
self.conn = None
def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
_bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))
def put(self, bucket, fnm, binary):

---

@@ -36,7 +36,7 @@ class RAGFlowAzureSpnBlob(object):
self.conn = None
def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
_bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
f = self.conn.create_file(fnm)
f.append_data(binary, offset=0, length=len(binary))
return f.flush_data(len(binary))

---

@@ -132,7 +132,8 @@ class ESConnection(DocStoreConnection):
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v: continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):