mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix errors detected by Ruff (#3918)
### What problem does this PR solve? Fix errors detected by Ruff ### Type of change - [x] Refactoring
This commit is contained in:
@ -15,13 +15,14 @@
|
||||
#
|
||||
import pathlib
|
||||
import re
|
||||
from .user_service import UserService
|
||||
from .user_service import UserService as UserService
|
||||
|
||||
|
||||
def duplicate_name(query_func, **kwargs):
|
||||
fnm = kwargs["name"]
|
||||
objs = query_func(**kwargs)
|
||||
if not objs: return fnm
|
||||
if not objs:
|
||||
return fnm
|
||||
ext = pathlib.Path(fnm).suffix #.jpg
|
||||
nm = re.sub(r"%s$"%ext, "", fnm)
|
||||
r = re.search(r"\(([0-9]+)\)$", nm)
|
||||
@ -31,8 +32,8 @@ def duplicate_name(query_func, **kwargs):
|
||||
nm = re.sub(r"\([0-9]+\)$", "", nm)
|
||||
c += 1
|
||||
nm = f"{nm}({c})"
|
||||
if ext: nm += f"{ext}"
|
||||
if ext:
|
||||
nm += f"{ext}"
|
||||
|
||||
kwargs["name"] = nm
|
||||
return duplicate_name(query_func, **kwargs)
|
||||
|
||||
|
||||
@ -64,7 +64,8 @@ class API4ConversationService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def stats(cls, tenant_id, from_date, to_date, source=None):
|
||||
if len(to_date) == 10: to_date += " 23:59:59"
|
||||
if len(to_date) == 10:
|
||||
to_date += " 23:59:59"
|
||||
return cls.model.select(
|
||||
cls.model.create_date.truncate("day").alias("dt"),
|
||||
peewee.fn.COUNT(
|
||||
|
||||
@ -13,9 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
import peewee
|
||||
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
|
||||
from api.db.db_models import DB, CanvasTemplate, UserCanvas
|
||||
from api.db.services.common_service import CommonService
|
||||
|
||||
|
||||
|
||||
@ -115,7 +115,7 @@ class CommonService:
|
||||
try:
|
||||
obj = cls.model.query(id=pid)[0]
|
||||
return True, obj
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -106,15 +106,15 @@ def message_fit_in(msg, max_length=4000):
|
||||
return c, msg
|
||||
|
||||
ll = num_tokens_from_string(msg_[0]["content"])
|
||||
l = num_tokens_from_string(msg_[-1]["content"])
|
||||
if ll / (ll + l) > 0.8:
|
||||
ll2 = num_tokens_from_string(msg_[-1]["content"])
|
||||
if ll / (ll + ll2) > 0.8:
|
||||
m = msg_[0]["content"]
|
||||
m = encoder.decode(encoder.encode(m)[:max_length - l])
|
||||
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
|
||||
msg[0]["content"] = m
|
||||
return max_length, msg
|
||||
|
||||
m = msg_[1]["content"]
|
||||
m = encoder.decode(encoder.encode(m)[:max_length - l])
|
||||
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
|
||||
msg[1]["content"] = m
|
||||
return max_length, msg
|
||||
|
||||
@ -257,7 +257,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
|
||||
refs = deepcopy(kbinfos)
|
||||
@ -433,13 +434,15 @@ def relevant(tenant_id, llm_id, question, contents: list):
|
||||
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
||||
No other words needed except 'yes' or 'no'.
|
||||
"""
|
||||
if not contents:return False
|
||||
if not contents:
|
||||
return False
|
||||
contents = "Documents: \n" + " - ".join(contents)
|
||||
contents = f"Question: {question}\n" + contents
|
||||
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
|
||||
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
|
||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
|
||||
if ans.lower().find("yes") >= 0: return True
|
||||
if ans.lower().find("yes") >= 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ -481,8 +484,10 @@ Requirements:
|
||||
]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple): kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >=0: return ""
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >=0:
|
||||
return ""
|
||||
return kwd
|
||||
|
||||
|
||||
@ -508,8 +513,10 @@ Requirements:
|
||||
]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple): kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >= 0: return ""
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
return ""
|
||||
return kwd
|
||||
|
||||
|
||||
@ -520,7 +527,8 @@ def full_question(tenant_id, llm_id, messages):
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
conv = []
|
||||
for m in messages:
|
||||
if m["role"] not in ["user", "assistant"]: continue
|
||||
if m["role"] not in ["user", "assistant"]:
|
||||
continue
|
||||
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
|
||||
conv = "\n".join(conv)
|
||||
today = datetime.date.today().isoformat()
|
||||
@ -581,7 +589,8 @@ Output: What's the weather in Rochester on {tomorrow}?
|
||||
|
||||
|
||||
def tts(tts_mdl, text):
|
||||
if not tts_mdl or not text: return
|
||||
if not tts_mdl or not text:
|
||||
return
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
@ -641,7 +650,8 @@ def ask(question, kb_ids, tenant_id):
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
refs = deepcopy(kbinfos)
|
||||
for c in refs["chunks"]:
|
||||
|
||||
@ -532,7 +532,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
try:
|
||||
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
|
||||
ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32: raise Exception("Few content: " + mind_map)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
cks.append({
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
|
||||
@ -20,7 +20,7 @@ from api.db.db_models import DB
|
||||
from api.db.db_models import File, File2Document
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
|
||||
class File2DocumentService(CommonService):
|
||||
@ -63,7 +63,7 @@ class File2DocumentService(CommonService):
|
||||
def update_by_file_id(cls, file_id, obj):
|
||||
obj["update_time"] = current_timestamp()
|
||||
obj["update_date"] = datetime_format(datetime.now())
|
||||
num = cls.model.update(obj).where(cls.model.id == file_id).execute()
|
||||
# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
|
||||
e, obj = cls.get_by_id(cls.model.id)
|
||||
return obj
|
||||
|
||||
|
||||
@ -85,7 +85,8 @@ class FileService(CommonService):
|
||||
.join(Document, on=(File2Document.document_id == Document.id))
|
||||
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
|
||||
.where(cls.model.id == file_id))
|
||||
if not kbs: return []
|
||||
if not kbs:
|
||||
return []
|
||||
kbs_info_list = []
|
||||
for kb in list(kbs.dicts()):
|
||||
kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
|
||||
@ -304,7 +305,8 @@ class FileService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
|
||||
for _ in File2DocumentService.get_by_document_id(doc["id"]): return
|
||||
for _ in File2DocumentService.get_by_document_id(doc["id"]):
|
||||
return
|
||||
file = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": kb_folder_id,
|
||||
|
||||
@ -107,7 +107,8 @@ class TenantLLMService(CommonService):
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
if model_config: model_config = model_config.to_dict()
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
if not model_config:
|
||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
|
||||
@ -57,28 +57,33 @@ class TaskService(CommonService):
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
Tenant.llm_id,
|
||||
cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
|
||||
cls.model.update_time,
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == task_id)
|
||||
)
|
||||
docs = list(docs.dicts())
|
||||
if not docs: return None
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
msg = "\nTask has been received."
|
||||
prog = random.random() / 10.
|
||||
prog = random.random() / 10.0
|
||||
if docs[0]["retry_count"] >= 3:
|
||||
msg = "\nERROR: Task is abandoned after 3 times attempts."
|
||||
prog = -1
|
||||
|
||||
cls.model.update(progress_msg=cls.model.progress_msg + msg,
|
||||
progress=prog,
|
||||
retry_count=docs[0]["retry_count"]+1
|
||||
).where(
|
||||
cls.model.id == docs[0]["id"]).execute()
|
||||
cls.model.update(
|
||||
progress_msg=cls.model.progress_msg + msg,
|
||||
progress=prog,
|
||||
retry_count=docs[0]["retry_count"] + 1,
|
||||
).where(cls.model.id == docs[0]["id"]).execute()
|
||||
|
||||
if docs[0]["retry_count"] >= 3: return None
|
||||
if docs[0]["retry_count"] >= 3:
|
||||
return None
|
||||
|
||||
return docs[0]
|
||||
|
||||
@ -86,21 +91,44 @@ class TaskService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_ongoing_doc_name(cls):
|
||||
with DB.lock("get_task", -1):
|
||||
docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
|
||||
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
||||
.join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
|
||||
.join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
|
||||
docs = (
|
||||
cls.model.select(
|
||||
*[Document.id, Document.kb_id, Document.location, File.parent_id]
|
||||
)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(
|
||||
File2Document,
|
||||
on=(File2Document.document_id == Document.id),
|
||||
join_type=JOIN.LEFT_OUTER,
|
||||
)
|
||||
.join(
|
||||
File,
|
||||
on=(File2Document.file_id == File.id),
|
||||
join_type=JOIN.LEFT_OUTER,
|
||||
)
|
||||
.where(
|
||||
Document.status == StatusEnum.VALID.value,
|
||||
Document.run == TaskStatus.RUNNING.value,
|
||||
~(Document.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress < 1,
|
||||
cls.model.create_time >= current_timestamp() - 1000 * 600
|
||||
cls.model.create_time >= current_timestamp() - 1000 * 600,
|
||||
)
|
||||
)
|
||||
docs = list(docs.dicts())
|
||||
if not docs: return []
|
||||
if not docs:
|
||||
return []
|
||||
|
||||
return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
|
||||
return list(
|
||||
set(
|
||||
[
|
||||
(
|
||||
d["parent_id"] if d["parent_id"] else d["kb_id"],
|
||||
d["location"],
|
||||
)
|
||||
for d in docs
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -118,28 +146,30 @@ class TaskService(CommonService):
|
||||
def update_progress(cls, id, info):
|
||||
if os.environ.get("MACOS"):
|
||||
if info["progress_msg"]:
|
||||
cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
|
||||
cls.model.id == id).execute()
|
||||
cls.model.update(
|
||||
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
|
||||
).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
cls.model.update(progress=info["progress"]).where(
|
||||
cls.model.id == id).execute()
|
||||
cls.model.id == id
|
||||
).execute()
|
||||
return
|
||||
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
|
||||
cls.model.id == id).execute()
|
||||
cls.model.update(
|
||||
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
|
||||
).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
cls.model.update(progress=info["progress"]).where(
|
||||
cls.model.id == id).execute()
|
||||
cls.model.id == id
|
||||
).execute()
|
||||
|
||||
|
||||
def queue_tasks(doc: dict, bucket: str, name: str):
|
||||
def new_task():
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"]
|
||||
}
|
||||
return {"id": get_uuid(), "doc_id": doc["id"]}
|
||||
|
||||
tsks = []
|
||||
|
||||
if doc["type"] == FileType.PDF.value:
|
||||
@ -150,8 +180,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
||||
if doc["parser_id"] == "paper":
|
||||
page_size = doc["parser_config"].get("task_page_size", 22)
|
||||
if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
|
||||
page_size = 10 ** 9
|
||||
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
|
||||
page_size = 10**9
|
||||
page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
|
||||
for s, e in page_ranges:
|
||||
s -= 1
|
||||
s = max(0, s)
|
||||
@ -177,4 +207,6 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
|
||||
for t in tsks:
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
|
||||
assert REDIS_CONN.queue_product(
|
||||
SVR_QUEUE_NAME, message=t
|
||||
), "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
@ -22,7 +22,7 @@ from api.db import UserTenantRole
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import User, Tenant
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
|
||||
from api.utils import get_uuid, current_timestamp, datetime_format
|
||||
from api.db import StatusEnum
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user