mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
apply pep8 formalize (#155)
This commit is contained in:
@ -121,7 +121,9 @@ def get():
|
||||
"important_kwd")
|
||||
def set():
|
||||
req = request.json
|
||||
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = huqie.qie(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
@ -140,10 +142,16 @@ def set():
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
|
||||
if len(arr) != 2: return get_data_error_result(retmsg="Q&A must be separated by TAB/ENTER key.")
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(
|
||||
retmsg="Q&A must be separated by TAB/ENTER key.")
|
||||
q, a = rmPrefix(arr[0]), rmPrefix[arr[1]]
|
||||
d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q + a]))
|
||||
d = beAdoc(d, arr[0], arr[1], not any(
|
||||
[huqie.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
@ -177,7 +185,8 @@ def switch():
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
if not ELASTICSEARCH.deleteByQuery(Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
|
||||
if not ELASTICSEARCH.deleteByQuery(
|
||||
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
|
||||
return get_data_error_result(retmsg="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
|
||||
@ -100,7 +100,10 @@ def rm():
|
||||
def list_convsersation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
|
||||
convs = ConversationService.query(
|
||||
dialog_id=dialog_id,
|
||||
order_by=ConversationService.model.create_time,
|
||||
reverse=True)
|
||||
convs = [d.to_dict() for d in convs]
|
||||
return get_json_result(data=convs)
|
||||
except Exception as e:
|
||||
@ -111,19 +114,24 @@ def message_fit_in(msg, max_length=4000):
|
||||
def count():
|
||||
nonlocal msg
|
||||
tks_cnts = []
|
||||
for m in msg: tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
for m in msg:
|
||||
tks_cnts.append(
|
||||
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
total = 0
|
||||
for m in tks_cnts: total += m["count"]
|
||||
for m in tks_cnts:
|
||||
total += m["count"]
|
||||
return total
|
||||
|
||||
c = count()
|
||||
if c < max_length: return c, msg
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
msg_ = [m for m in msg[:-1] if m.role == "system"]
|
||||
msg_.append(msg[-1])
|
||||
msg = msg_
|
||||
c = count()
|
||||
if c < max_length: return c, msg
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
ll = num_tokens_from_string(msg_[0].content)
|
||||
l = num_tokens_from_string(msg_[-1].content)
|
||||
@ -146,8 +154,10 @@ def completion():
|
||||
req = request.json
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system": continue
|
||||
if m["role"] == "assistant" and not msg: continue
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append({"role": m["role"], "content": m["content"]})
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
@ -160,7 +170,8 @@ def completion():
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
ans = chat(dia, msg, **req)
|
||||
if not conv.reference: conv.reference = []
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.reference.append(ans["reference"])
|
||||
conv.message.append({"role": "assistant", "content": ans["answer"]})
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
@ -180,52 +191,67 @@ def chat(dialog, messages, **kwargs):
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
## try to use sql if field mapping is good to go
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge": continue
|
||||
if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"])
|
||||
if p["key"] == "knowledge":
|
||||
continue
|
||||
if p["key"] not in kwargs and not p["optional"]:
|
||||
raise KeyError("Miss parameter: " + p["key"])
|
||||
if p["key"] not in kwargs:
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||
prompt_config["system"] = prompt_config["system"].replace(
|
||||
"{%s}" % p["key"], " ")
|
||||
|
||||
for _ in range(len(questions)//2):
|
||||
for _ in range(len(questions) // 2):
|
||||
questions.append(questions[-1])
|
||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
||||
kbinfos = {"total":0, "chunks":[],"doc_aggs":[]}
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
else:
|
||||
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight, top=1024, aggs=False)
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight, top=1024, aggs=False)
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
chat_logger.info(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
return {
|
||||
"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
|
||||
kwargs["knowledge"] = "\n".join(knowledges)
|
||||
gen_conf = dialog.llm_setting
|
||||
msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
|
||||
msg = [{"role": m["role"], "content": m["content"]}
|
||||
for m in messages if m["role"] != "system"]
|
||||
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
|
||||
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
|
||||
chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
|
||||
gen_conf["max_tokens"] = min(
|
||||
gen_conf["max_tokens"],
|
||||
llm.max_tokens - used_token_count)
|
||||
answer = chat_mdl.chat(
|
||||
prompt_config["system"].format(
|
||||
**kwargs), msg, gen_conf)
|
||||
chat_logger.info("User: {}|Assistant: {}".format(
|
||||
msg[-1]["content"], answer))
|
||||
|
||||
if knowledges:
|
||||
answer, idx = retrievaler.insert_citations(answer,
|
||||
[ck["content_ltks"] for ck in kbinfos["chunks"]],
|
||||
[ck["vector"] for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
kbinfos["doc_aggs"] = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
kbinfos["doc_aggs"] = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
for c in kbinfos["chunks"]:
|
||||
if c.get("vector"): del c["vector"]
|
||||
if c.get("vector"):
|
||||
del c["vector"]
|
||||
return {"answer": answer, "reference": kbinfos}
|
||||
|
||||
|
||||
@ -245,9 +271,11 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
question
|
||||
)
|
||||
tried_times = 0
|
||||
|
||||
def get_table():
|
||||
nonlocal sys_prompt, user_promt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
|
||||
"temperature": 0.06})
|
||||
print(user_promt, sql)
|
||||
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
@ -262,8 +290,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
else:
|
||||
flds = []
|
||||
for k in field_map.keys():
|
||||
if k in forbidden_select_fields4resume:continue
|
||||
if len(flds) > 11:break
|
||||
if k in forbidden_select_fields4resume:
|
||||
continue
|
||||
if len(flds) > 11:
|
||||
break
|
||||
flds.append(k)
|
||||
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
||||
|
||||
@ -284,13 +314,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
|
||||
问题如下:
|
||||
{}
|
||||
|
||||
|
||||
你上一次给出的错误SQL如下:
|
||||
{}
|
||||
|
||||
|
||||
后台报错如下:
|
||||
{}
|
||||
|
||||
|
||||
请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
@ -302,16 +332,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
|
||||
chat_logger.info("GET table: {}".format(tbl))
|
||||
print(tbl)
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0:
|
||||
return None, None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
|
||||
docid_idx = set([ii for ii, c in enumerate(
|
||||
tbl["columns"]) if c["name"] == "doc_id"])
|
||||
docnm_idx = set([ii for ii, c in enumerate(
|
||||
tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(
|
||||
len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
|
||||
|
||||
# compose markdown table
|
||||
clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
||||
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
||||
("|------|" if docid_idx and docid_idx else "")
|
||||
rows = ["|" +
|
||||
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
||||
"|" for r in tbl["rows"]]
|
||||
if not docid_idx or not docnm_idx:
|
||||
chat_logger.warning("SQL missing field: " + sql)
|
||||
return "\n".join([clmns, line, "\n".join(rows)]), []
|
||||
@ -328,5 +366,5 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
return {
|
||||
"answer": "\n".join([clmns, line, rows]),
|
||||
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
|
||||
"doc_aggs":[{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
|
||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
|
||||
}
|
||||
|
||||
@ -55,7 +55,8 @@ def set_dialog():
|
||||
}
|
||||
prompt_config = req.get("prompt_config", default_prompt)
|
||||
|
||||
if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"]
|
||||
if not prompt_config["system"]:
|
||||
prompt_config["system"] = default_prompt["system"]
|
||||
# if len(prompt_config["parameters"]) < 1:
|
||||
# prompt_config["parameters"] = default_prompt["parameters"]
|
||||
# for p in prompt_config["parameters"]:
|
||||
@ -63,16 +64,21 @@ def set_dialog():
|
||||
# else: prompt_config["parameters"].append(default_prompt["parameters"][0])
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["optional"]: continue
|
||||
if p["optional"]:
|
||||
continue
|
||||
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
|
||||
return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
|
||||
return get_data_error_result(
|
||||
retmsg="Parameter '{}' is not used".format(p["key"]))
|
||||
|
||||
try:
|
||||
e, tenant = TenantService.get_by_id(current_user.id)
|
||||
if not e: return get_data_error_result(retmsg="Tenant not found!")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
llm_id = req.get("llm_id", tenant.llm_id)
|
||||
if not dialog_id:
|
||||
if not req.get("kb_ids"):return get_data_error_result(retmsg="Fail! Please select knowledgebase!")
|
||||
if not req.get("kb_ids"):
|
||||
return get_data_error_result(
|
||||
retmsg="Fail! Please select knowledgebase!")
|
||||
dia = {
|
||||
"id": get_uuid(),
|
||||
"tenant_id": current_user.id,
|
||||
@ -86,17 +92,21 @@ def set_dialog():
|
||||
"similarity_threshold": similarity_threshold,
|
||||
"vector_similarity_weight": vector_similarity_weight
|
||||
}
|
||||
if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!")
|
||||
if not DialogService.save(**dia):
|
||||
return get_data_error_result(retmsg="Fail to new a dialog!")
|
||||
e, dia = DialogService.get_by_id(dia["id"])
|
||||
if not e: return get_data_error_result(retmsg="Fail to new a dialog!")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Fail to new a dialog!")
|
||||
return get_json_result(data=dia.to_json())
|
||||
else:
|
||||
del req["dialog_id"]
|
||||
if "kb_names" in req: del req["kb_names"]
|
||||
if "kb_names" in req:
|
||||
del req["kb_names"]
|
||||
if not DialogService.update_by_id(dialog_id, req):
|
||||
return get_data_error_result(retmsg="Dialog not found!")
|
||||
e, dia = DialogService.get_by_id(dialog_id)
|
||||
if not e: return get_data_error_result(retmsg="Fail to update a dialog!")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Fail to update a dialog!")
|
||||
dia = dia.to_dict()
|
||||
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
|
||||
return get_json_result(data=dia)
|
||||
@ -110,7 +120,8 @@ def get():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
e, dia = DialogService.get_by_id(dialog_id)
|
||||
if not e: return get_data_error_result(retmsg="Dialog not found!")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Dialog not found!")
|
||||
dia = dia.to_dict()
|
||||
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
|
||||
return get_json_result(data=dia)
|
||||
@ -122,7 +133,8 @@ def get_kb_names(kb_ids):
|
||||
ids, nms = [], []
|
||||
for kid in kb_ids:
|
||||
e, kb = KnowledgebaseService.get_by_id(kid)
|
||||
if not e or kb.status != StatusEnum.VALID.value: continue
|
||||
if not e or kb.status != StatusEnum.VALID.value:
|
||||
continue
|
||||
ids.append(kid)
|
||||
nms.append(kb.name)
|
||||
return ids, nms
|
||||
@ -132,7 +144,11 @@ def get_kb_names(kb_ids):
|
||||
@login_required
|
||||
def list():
|
||||
try:
|
||||
diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time)
|
||||
diags = DialogService.query(
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value,
|
||||
reverse=True,
|
||||
order_by=DialogService.model.create_time)
|
||||
diags = [d.to_dict() for d in diags]
|
||||
for d in diags:
|
||||
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
|
||||
@ -147,7 +163,8 @@ def list():
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
DialogService.update_many_by_id([{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
|
||||
DialogService.update_many_by_id(
|
||||
[{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -57,6 +57,9 @@ def upload():
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
if DocumentService.get_doc_count(kb.tenant_id) >= 128:
|
||||
return get_data_error_result(
|
||||
retmsg="Exceed the maximum file number of a free user!")
|
||||
|
||||
filename = duplicate_name(
|
||||
DocumentService.query,
|
||||
@ -215,9 +218,11 @@ def rm():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
|
||||
if not DocumentService.delete(doc):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Document removal)!")
|
||||
@ -245,7 +250,8 @@ def run():
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@ -261,7 +267,8 @@ def rename():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
doc.name.lower()).suffix:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
retmsg="The extension of file can't be changed",
|
||||
@ -294,7 +301,10 @@ def get(doc_id):
|
||||
if doc.type == FileType.VISUAL.value:
|
||||
response.headers.set('Content-Type', 'image/%s' % ext.group(1))
|
||||
else:
|
||||
response.headers.set('Content-Type', 'application/%s' % ext.group(1))
|
||||
response.headers.set(
|
||||
'Content-Type',
|
||||
'application/%s' %
|
||||
ext.group(1))
|
||||
return response
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -313,9 +323,11 @@ def change_parser():
|
||||
if "parser_config" in req:
|
||||
if req["parser_config"] == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else: return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
|
||||
if doc.type == FileType.VISUAL or re.search(
|
||||
r"\.(ppt|pptx|pages)$", doc.name):
|
||||
return get_data_error_result(retmsg="Not supported yet!")
|
||||
|
||||
e = DocumentService.update_by_id(doc.id,
|
||||
@ -332,7 +344,8 @@ def change_parser():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
|
||||
@ -33,15 +33,21 @@ from api.utils.api_utils import get_json_result
|
||||
def create():
|
||||
req = request.json
|
||||
req["name"] = req["name"].strip()
|
||||
req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)
|
||||
req["name"] = duplicate_name(
|
||||
KnowledgebaseService.query,
|
||||
name=req["name"],
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value)
|
||||
try:
|
||||
req["id"] = get_uuid()
|
||||
req["tenant_id"] = current_user.id
|
||||
req["created_by"] = current_user.id
|
||||
e, t = TenantService.get_by_id(current_user.id)
|
||||
if not e: return get_data_error_result(retmsg="Tenant not found.")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Tenant not found.")
|
||||
req["embd_id"] = t.embd_id
|
||||
if not KnowledgebaseService.save(**req): return get_data_error_result()
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id": req["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -54,21 +60,29 @@ def update():
|
||||
req = request.json
|
||||
req["name"] = req["name"].strip()
|
||||
try:
|
||||
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
if not KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
||||
if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
|
||||
if req["name"].lower() != kb.name.lower() \
|
||||
and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1:
|
||||
return get_data_error_result(retmsg="Duplicated knowledgebase name.")
|
||||
and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
|
||||
return get_data_error_result(
|
||||
retmsg="Duplicated knowledgebase name.")
|
||||
|
||||
del req["kb_id"]
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result()
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_data_error_result()
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Knowledgebase rename)!")
|
||||
|
||||
return get_json_result(data=kb.to_json())
|
||||
except Exception as e:
|
||||
@ -81,7 +95,9 @@ def detail():
|
||||
kb_id = request.args["kb_id"]
|
||||
try:
|
||||
kb = KnowledgebaseService.get_detail(kb_id)
|
||||
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||
if not kb:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
return get_json_result(data=kb)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -96,7 +112,8 @@ def list():
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
|
||||
kbs = KnowledgebaseService.get_by_tenant_ids(
|
||||
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
|
||||
return get_json_result(data=kbs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -108,10 +125,15 @@ def list():
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
if not KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.INVALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
|
||||
if not KnowledgebaseService.update_by_id(
|
||||
req["kb_id"], {"status": StatusEnum.INVALID.value}):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Knowledgebase removal)!")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
@ -48,30 +48,42 @@ def set_api_key():
|
||||
req["api_key"], llm.llm_name)
|
||||
try:
|
||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||
if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
|
||||
if len(arr[0]) == 0 or tc == 0:
|
||||
raise Exception("Fail")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
|
||||
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
||||
mdl = ChatModel[factory](
|
||||
req["api_key"], llm.llm_name)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||
if not tc: raise Exception(m)
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
|
||||
"temperature": 0.9})
|
||||
if not tc:
|
||||
raise Exception(m)
|
||||
chat_passed = True
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
||||
e)
|
||||
|
||||
if msg: return get_data_error_result(retmsg=msg)
|
||||
if msg:
|
||||
return get_data_error_result(retmsg=msg)
|
||||
|
||||
llm = {
|
||||
"api_key": req["api_key"]
|
||||
}
|
||||
for n in ["model_type", "llm_name"]:
|
||||
if n in req: llm[n] = req[n]
|
||||
if n in req:
|
||||
llm[n] = req[n]
|
||||
|
||||
if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm):
|
||||
if not TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm):
|
||||
for llm in LLMService.query(fid=factory):
|
||||
TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"])
|
||||
TenantLLMService.save(
|
||||
tenant_id=current_user.id,
|
||||
llm_factory=factory,
|
||||
llm_name=llm.llm_name,
|
||||
model_type=llm.model_type,
|
||||
api_key=req["api_key"])
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -105,17 +117,19 @@ def list():
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
|
||||
llms = [m.to_dict()
|
||||
for m in llms if m.status == StatusEnum.VALID.value]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
|
||||
|
||||
res = {}
|
||||
for m in llms:
|
||||
if model_type and m["model_type"] != model_type: continue
|
||||
if m["fid"] not in res: res[m["fid"]] = []
|
||||
if model_type and m["model_type"] != model_type:
|
||||
continue
|
||||
if m["fid"] not in res:
|
||||
res[m["fid"]] = []
|
||||
res[m["fid"]].append(m)
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -40,13 +40,16 @@ def login():
|
||||
|
||||
email = request.json.get('email', "")
|
||||
users = UserService.query(email=email)
|
||||
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||
if not users:
|
||||
return get_json_result(
|
||||
data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||
|
||||
password = request.json.get('password')
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except:
|
||||
return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
|
||||
except BaseException:
|
||||
return get_json_result(
|
||||
data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
if user:
|
||||
@ -57,7 +60,8 @@ def login():
|
||||
msg = "Welcome back!"
|
||||
return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
|
||||
else:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||
retmsg='Email and Password do not match!')
|
||||
|
||||
|
||||
@manager.route('/github_callback', methods=['GET'])
|
||||
@ -65,7 +69,7 @@ def github_callback():
|
||||
import requests
|
||||
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
||||
"client_id": GITHUB_OAUTH.get("client_id"),
|
||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||
"code": request.args.get('code')
|
||||
}, headers={"Accept": "application/json"})
|
||||
res = res.json()
|
||||
@ -96,15 +100,17 @@ def github_callback():
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
})
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
if not users:
|
||||
raise Exception('Register user failure.')
|
||||
if len(users) > 1:
|
||||
raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return redirect("/?auth=%s"%user.get_id())
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
return redirect("/?error=%s"%str(e))
|
||||
return redirect("/?error=%s" % str(e))
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
@ -114,11 +120,18 @@ def github_callback():
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
import requests
|
||||
headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
headers = {"Accept": "application/json",
|
||||
'Authorization': f"token {access_token}"}
|
||||
res = requests.get(
|
||||
f"https://api.github.com/user?access_token={access_token}",
|
||||
headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json()
|
||||
user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"]
|
||||
email_info = requests.get(
|
||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||
headers=headers).json()
|
||||
user_info["email"] = next(
|
||||
(email for email in email_info if email['primary'] == True),
|
||||
None)["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
@ -138,13 +151,18 @@ def setting_user():
|
||||
request_data = request.json
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
|
||||
if not check_password_hash(
|
||||
current_user.password, decrypt(request_data["password"])):
|
||||
return get_json_result(
|
||||
data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
|
||||
|
||||
if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
|
||||
if new_password:
|
||||
update_dict["password"] = generate_password_hash(
|
||||
decrypt(new_password))
|
||||
|
||||
for k in request_data.keys():
|
||||
if k in ["password", "new_password"]:continue
|
||||
if k in ["password", "new_password"]:
|
||||
continue
|
||||
update_dict[k] = request_data[k]
|
||||
|
||||
try:
|
||||
@ -152,7 +170,8 @@ def setting_user():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
return get_json_result(
|
||||
data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/info", methods=["GET"])
|
||||
@ -173,11 +192,11 @@ def rollback_user_registration(user_id):
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).excute()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def user_register(user_id, user):
|
||||
user["id"] = user_id
|
||||
tenant = {
|
||||
@ -197,9 +216,14 @@ def user_register(user_id, user):
|
||||
}
|
||||
tenant_llm = []
|
||||
for llm in LLMService.query(fid=LLM_FACTORY):
|
||||
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
||||
tenant_llm.append({"tenant_id": user_id,
|
||||
"llm_factory": LLM_FACTORY,
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": API_KEY})
|
||||
|
||||
if not UserService.save(**user):return
|
||||
if not UserService.save(**user):
|
||||
return
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
@ -211,7 +235,8 @@ def user_register(user_id, user):
|
||||
def user_add():
|
||||
req = request.json
|
||||
if UserService.query(email=req["email"]):
|
||||
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
|
||||
return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
|
||||
retcode=RetCode.OPERATING_ERROR)
|
||||
@ -229,16 +254,19 @@ def user_add():
|
||||
user_id = get_uuid()
|
||||
try:
|
||||
users = user_register(user_id, user_dict)
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
if not users:
|
||||
raise Exception('Register user failure.')
|
||||
if len(users) > 1:
|
||||
raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
||||
return cors_reponse(data=user.to_json(),
|
||||
auth=user.get_id(), retmsg="Welcome aboard!")
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
return get_json_result(
|
||||
data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/tenant_info", methods=["GET"])
|
||||
|
||||
@ -50,7 +50,13 @@ def singleton(cls, *args, **kw):
|
||||
|
||||
|
||||
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
|
||||
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
|
||||
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
|
||||
"create",
|
||||
"start",
|
||||
"end",
|
||||
"update",
|
||||
"read_access",
|
||||
"write_access"}
|
||||
|
||||
|
||||
class LongTextField(TextField):
|
||||
@ -73,7 +79,8 @@ class JSONField(LongTextField):
|
||||
def python_value(self, value):
|
||||
if not value:
|
||||
return self.default_value
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
return utils.json_loads(
|
||||
value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
|
||||
|
||||
class ListField(JSONField):
|
||||
@ -81,7 +88,8 @@ class ListField(JSONField):
|
||||
|
||||
|
||||
class SerializedField(LongTextField):
|
||||
def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
|
||||
def __init__(self, serialized_type=SerializedType.PICKLE,
|
||||
object_hook=None, object_pairs_hook=None, **kwargs):
|
||||
self._serialized_type = serialized_type
|
||||
self._object_hook = object_hook
|
||||
self._object_pairs_hook = object_pairs_hook
|
||||
@ -95,7 +103,8 @@ class SerializedField(LongTextField):
|
||||
return None
|
||||
return utils.json_dumps(value, with_type=True)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
raise ValueError(
|
||||
f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
def python_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
@ -103,9 +112,11 @@ class SerializedField(LongTextField):
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return {}
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
return utils.json_loads(
|
||||
value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
raise ValueError(
|
||||
f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
|
||||
def is_continuous_field(cls: typing.Type) -> bool:
|
||||
@ -150,7 +161,8 @@ class BaseModel(Model):
|
||||
model_dict = self.__dict__['__data__']
|
||||
|
||||
if not only_primary_with:
|
||||
return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
|
||||
return {remove_field_name_prefix(
|
||||
k): v for k, v in model_dict.items()}
|
||||
|
||||
human_model_dict = {}
|
||||
for k in self._meta.primary_key.field_names:
|
||||
@ -184,17 +196,22 @@ class BaseModel(Model):
|
||||
if is_continuous_field(type(getattr(cls, attr_name))):
|
||||
if len(f_v) == 2:
|
||||
for i, v in enumerate(f_v):
|
||||
if isinstance(v, str) and f_n in auto_date_timestamp_field():
|
||||
if isinstance(
|
||||
v, str) and f_n in auto_date_timestamp_field():
|
||||
# time type: %Y-%m-%d %H:%M:%S
|
||||
f_v[i] = utils.date_string_to_timestamp(v)
|
||||
lt_value = f_v[0]
|
||||
gt_value = f_v[1]
|
||||
if lt_value is not None and gt_value is not None:
|
||||
filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
|
||||
filters.append(
|
||||
cls.getter_by(attr_name).between(
|
||||
lt_value, gt_value))
|
||||
elif lt_value is not None:
|
||||
filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
|
||||
filters.append(
|
||||
operator.attrgetter(attr_name)(cls) >= lt_value)
|
||||
elif gt_value is not None:
|
||||
filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
|
||||
filters.append(
|
||||
operator.attrgetter(attr_name)(cls) <= gt_value)
|
||||
else:
|
||||
filters.append(operator.attrgetter(attr_name)(cls) << f_v)
|
||||
else:
|
||||
@ -205,9 +222,11 @@ class BaseModel(Model):
|
||||
if not order_by or not hasattr(cls, f"{order_by}"):
|
||||
order_by = "create_time"
|
||||
if reverse is True:
|
||||
query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
|
||||
query_records = query_records.order_by(
|
||||
cls.getter_by(f"{order_by}").desc())
|
||||
elif reverse is False:
|
||||
query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
|
||||
query_records = query_records.order_by(
|
||||
cls.getter_by(f"{order_by}").asc())
|
||||
return [query_record for query_record in query_records]
|
||||
else:
|
||||
return []
|
||||
@ -215,7 +234,8 @@ class BaseModel(Model):
|
||||
@classmethod
|
||||
def insert(cls, __data=None, **insert):
|
||||
if isinstance(__data, dict) and __data:
|
||||
__data[cls._meta.combined["create_time"]] = utils.current_timestamp()
|
||||
__data[cls._meta.combined["create_time"]
|
||||
] = utils.current_timestamp()
|
||||
if insert:
|
||||
insert["create_time"] = utils.current_timestamp()
|
||||
|
||||
@ -228,7 +248,8 @@ class BaseModel(Model):
|
||||
if not normalized:
|
||||
return {}
|
||||
|
||||
normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
|
||||
normalized[cls._meta.combined["update_time"]
|
||||
] = utils.current_timestamp()
|
||||
|
||||
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
|
||||
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
|
||||
@ -241,7 +262,8 @@ class BaseModel(Model):
|
||||
|
||||
|
||||
class JsonSerializedField(SerializedField):
|
||||
def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
|
||||
def __init__(self, object_hook=utils.from_dict_hook,
|
||||
object_pairs_hook=None, **kwargs):
|
||||
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook, **kwargs)
|
||||
|
||||
@ -251,7 +273,8 @@ class BaseDataBase:
|
||||
def __init__(self):
|
||||
database_config = DATABASE.copy()
|
||||
db_name = database_config.pop("name")
|
||||
self.database_connection = PooledMySQLDatabase(db_name, **database_config)
|
||||
self.database_connection = PooledMySQLDatabase(
|
||||
db_name, **database_config)
|
||||
stat_logger.info('init mysql database on cluster mode successfully')
|
||||
|
||||
|
||||
@ -263,7 +286,8 @@ class DatabaseLock:
|
||||
|
||||
def lock(self):
|
||||
# SQL parameters only support %s format placeholders
|
||||
cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
|
||||
cursor = self.db.execute_sql(
|
||||
"SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
|
||||
ret = cursor.fetchone()
|
||||
if ret[0] == 0:
|
||||
raise Exception(f'acquire mysql lock {self.lock_name} timeout')
|
||||
@ -273,10 +297,12 @@ class DatabaseLock:
|
||||
raise Exception(f'failed to acquire lock {self.lock_name}')
|
||||
|
||||
def unlock(self):
|
||||
cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
|
||||
cursor = self.db.execute_sql(
|
||||
"SELECT RELEASE_LOCK(%s)", (self.lock_name,))
|
||||
ret = cursor.fetchone()
|
||||
if ret[0] == 0:
|
||||
raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
|
||||
raise Exception(
|
||||
f'mysql lock {self.lock_name} was not established by this thread')
|
||||
elif ret[0] == 1:
|
||||
return True
|
||||
else:
|
||||
@ -350,17 +376,37 @@ class User(DataBaseModel, UserMixin):
|
||||
access_token = CharField(max_length=255, null=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nicky name")
|
||||
password = CharField(max_length=255, null=True, help_text="password")
|
||||
email = CharField(max_length=255, null=False, help_text="email", index=True)
|
||||
email = CharField(
|
||||
max_length=255,
|
||||
null=False,
|
||||
help_text="email",
|
||||
index=True)
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
|
||||
color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright")
|
||||
timezone = CharField(max_length=64, null=True, help_text="Timezone", default="UTC+8\tAsia/Shanghai")
|
||||
language = CharField(
|
||||
max_length=32,
|
||||
null=True,
|
||||
help_text="English|Chinese",
|
||||
default="Chinese")
|
||||
color_schema = CharField(
|
||||
max_length=32,
|
||||
null=True,
|
||||
help_text="Bright|Dark",
|
||||
default="Bright")
|
||||
timezone = CharField(
|
||||
max_length=64,
|
||||
null=True,
|
||||
help_text="Timezone",
|
||||
default="UTC+8\tAsia/Shanghai")
|
||||
last_login_time = DateTimeField(null=True)
|
||||
is_authenticated = CharField(max_length=1, null=False, default="1")
|
||||
is_active = CharField(max_length=1, null=False, default="1")
|
||||
is_anonymous = CharField(max_length=1, null=False, default="0")
|
||||
login_channel = CharField(null=True, help_text="from which user login")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
is_superuser = BooleanField(null=True, help_text="is root", default=False)
|
||||
|
||||
def __str__(self):
|
||||
@ -379,12 +425,28 @@ class Tenant(DataBaseModel):
|
||||
name = CharField(max_length=100, null=True, help_text="Tenant name")
|
||||
public_key = CharField(max_length=255, null=True)
|
||||
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
|
||||
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
|
||||
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
|
||||
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
|
||||
parser_ids = CharField(max_length=256, null=False, help_text="document processors")
|
||||
embd_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default embedding model ID")
|
||||
asr_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default ASR model ID")
|
||||
img2txt_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default image to text model ID")
|
||||
parser_ids = CharField(
|
||||
max_length=256,
|
||||
null=False,
|
||||
help_text="document processors")
|
||||
credit = IntegerField(default=512)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "tenant"
|
||||
@ -396,7 +458,11 @@ class UserTenant(DataBaseModel):
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
|
||||
invited_by = CharField(max_length=32, null=False)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "user_tenant"
|
||||
@ -408,17 +474,32 @@ class InvitationCode(DataBaseModel):
|
||||
visit_time = DateTimeField(null=True)
|
||||
user_id = CharField(max_length=32, null=True)
|
||||
tenant_id = CharField(max_length=32, null=True)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "invitation_code"
|
||||
|
||||
|
||||
class LLMFactories(DataBaseModel):
|
||||
name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
|
||||
name = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="LLM factory name",
|
||||
primary_key=True)
|
||||
logo = TextField(null=True, help_text="llm logo base64")
|
||||
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
tags = CharField(
|
||||
max_length=255,
|
||||
null=False,
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
@ -429,12 +510,27 @@ class LLMFactories(DataBaseModel):
|
||||
|
||||
class LLM(DataBaseModel):
|
||||
# LLMs dictionary
|
||||
llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True, primary_key=True)
|
||||
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
llm_name = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="LLM name",
|
||||
index=True,
|
||||
primary_key=True)
|
||||
model_type = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
||||
max_tokens = IntegerField(default=0)
|
||||
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
tags = CharField(
|
||||
max_length=255,
|
||||
null=False,
|
||||
help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
|
||||
def __str__(self):
|
||||
return self.llm_name
|
||||
@ -445,9 +541,19 @@ class LLM(DataBaseModel):
|
||||
|
||||
class TenantLLM(DataBaseModel):
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
||||
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
|
||||
llm_factory = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="LLM factory name")
|
||||
model_type = CharField(
|
||||
max_length=128,
|
||||
null=True,
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
llm_name = CharField(
|
||||
max_length=128,
|
||||
null=True,
|
||||
help_text="LLM name",
|
||||
default="")
|
||||
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
used_tokens = IntegerField(default=0)
|
||||
@ -464,11 +570,26 @@ class Knowledgebase(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
|
||||
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
|
||||
name = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="KB name",
|
||||
index=True)
|
||||
language = CharField(
|
||||
max_length=32,
|
||||
null=True,
|
||||
default="Chinese",
|
||||
help_text="English|Chinese")
|
||||
description = TextField(null=True, help_text="KB description")
|
||||
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
|
||||
permission = CharField(max_length=16, null=False, help_text="me|team", default="me")
|
||||
embd_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default embedding model ID")
|
||||
permission = CharField(
|
||||
max_length=16,
|
||||
null=False,
|
||||
help_text="me|team",
|
||||
default="me")
|
||||
created_by = CharField(max_length=32, null=False)
|
||||
doc_num = IntegerField(default=0)
|
||||
token_num = IntegerField(default=0)
|
||||
@ -476,9 +597,17 @@ class Knowledgebase(DataBaseModel):
|
||||
similarity_threshold = FloatField(default=0.2)
|
||||
vector_similarity_weight = FloatField(default=0.3)
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
|
||||
parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
parser_id = CharField(
|
||||
max_length=32,
|
||||
null=False,
|
||||
help_text="default parser ID",
|
||||
default=ParserType.NAIVE.value)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
@ -491,22 +620,50 @@ class Document(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
||||
parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
|
||||
parser_id = CharField(
|
||||
max_length=32,
|
||||
null=False,
|
||||
help_text="default parser ID")
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
source_type = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
default="local",
|
||||
help_text="where dose this document from")
|
||||
type = CharField(max_length=32, null=False, help_text="file extension")
|
||||
created_by = CharField(max_length=32, null=False, help_text="who created it")
|
||||
name = CharField(max_length=255, null=True, help_text="file name", index=True)
|
||||
location = CharField(max_length=255, null=True, help_text="where dose it store")
|
||||
created_by = CharField(
|
||||
max_length=32,
|
||||
null=False,
|
||||
help_text="who created it")
|
||||
name = CharField(
|
||||
max_length=255,
|
||||
null=True,
|
||||
help_text="file name",
|
||||
index=True)
|
||||
location = CharField(
|
||||
max_length=255,
|
||||
null=True,
|
||||
help_text="where dose it store")
|
||||
size = IntegerField(default=0)
|
||||
token_num = IntegerField(default=0)
|
||||
chunk_num = IntegerField(default=0)
|
||||
progress = FloatField(default=0)
|
||||
progress_msg = TextField(null=True, help_text="process message", default="")
|
||||
progress_msg = TextField(
|
||||
null=True,
|
||||
help_text="process message",
|
||||
default="")
|
||||
process_begin_at = DateTimeField(null=True)
|
||||
process_duation = FloatField(default=0)
|
||||
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
run = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="start to run processing or cancel.(1: run it; 2: cancel)",
|
||||
default="0")
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "document"
|
||||
@ -520,30 +677,52 @@ class Task(DataBaseModel):
|
||||
begin_at = DateTimeField(null=True)
|
||||
process_duation = FloatField(default=0)
|
||||
progress = FloatField(default=0)
|
||||
progress_msg = TextField(null=True, help_text="process message", default="")
|
||||
progress_msg = TextField(
|
||||
null=True,
|
||||
help_text="process message",
|
||||
default="")
|
||||
|
||||
|
||||
class Dialog(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
name = CharField(max_length=255, null=True, help_text="dialog application name")
|
||||
name = CharField(
|
||||
max_length=255,
|
||||
null=True,
|
||||
help_text="dialog application name")
|
||||
description = TextField(null=True, help_text="Dialog description")
|
||||
icon = TextField(null=True, help_text="icon base64 string")
|
||||
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
|
||||
language = CharField(
|
||||
max_length=32,
|
||||
null=True,
|
||||
default="Chinese",
|
||||
help_text="English|Chinese")
|
||||
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
|
||||
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
|
||||
"presence_penalty": 0.4, "max_tokens": 215})
|
||||
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
|
||||
prompt_type = CharField(
|
||||
max_length=16,
|
||||
null=False,
|
||||
default="simple",
|
||||
help_text="simple|advanced")
|
||||
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
|
||||
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
|
||||
|
||||
similarity_threshold = FloatField(default=0.2)
|
||||
vector_similarity_weight = FloatField(default=0.3)
|
||||
top_n = IntegerField(default=6)
|
||||
do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1")
|
||||
do_refer = CharField(
|
||||
max_length=1,
|
||||
null=False,
|
||||
help_text="it needs to insert reference index into answer or not",
|
||||
default="1")
|
||||
|
||||
kb_ids = JSONField(null=False, default=[])
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "dialog"
|
||||
|
||||
@ -32,8 +32,7 @@ LOGGER = getLogger()
|
||||
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
|
||||
DB.create_tables([model])
|
||||
|
||||
|
||||
for i,data in enumerate(data_source):
|
||||
for i, data in enumerate(data_source):
|
||||
current_time = current_timestamp() + i
|
||||
current_date = timestamp_to_date(current_time)
|
||||
if 'create_time' not in data:
|
||||
@ -55,7 +54,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
|
||||
|
||||
|
||||
def get_dynamic_db_model(base, job_id):
|
||||
return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))
|
||||
return type(base.model(
|
||||
table_index=get_dynamic_tracking_table_index(job_id=job_id)))
|
||||
|
||||
|
||||
def get_dynamic_tracking_table_index(job_id):
|
||||
@ -86,7 +86,9 @@ supported_operators = {
|
||||
'~': operator.inv,
|
||||
}
|
||||
|
||||
def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
|
||||
|
||||
def query_dict2expression(
|
||||
model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
|
||||
expression = []
|
||||
|
||||
for field, value in query.items():
|
||||
@ -95,7 +97,10 @@ def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[boo
|
||||
op, *val = value
|
||||
|
||||
field = getattr(model, f'f_{field}')
|
||||
value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val)
|
||||
value = supported_operators[op](
|
||||
field, val[0]) if op in supported_operators else getattr(
|
||||
field, op)(
|
||||
*val)
|
||||
expression.append(value)
|
||||
|
||||
return reduce(operator.iand, expression)
|
||||
|
||||
@ -61,45 +61,54 @@ def init_superuser():
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
|
||||
print(
|
||||
"【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
|
||||
|
||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
||||
msg = chat_mdl.chat(system="", history=[
|
||||
{"role": "user", "content": "Hello!"}], gen_conf={})
|
||||
if msg.find("ERROR: ") == 0:
|
||||
print("\33[91m【ERROR】\33[0m: ", "'{}' dosen't work. {}".format(tenant["llm_id"], msg))
|
||||
print(
|
||||
"\33[91m【ERROR】\33[0m: ",
|
||||
"'{}' dosen't work. {}".format(
|
||||
tenant["llm_id"],
|
||||
msg))
|
||||
embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
|
||||
v, c = embd_mdl.encode(["Hello!"])
|
||||
if c == 0:
|
||||
print("\33[91m【ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"]))
|
||||
print(
|
||||
"\33[91m【ERROR】\33[0m:",
|
||||
" '{}' dosen't work!".format(
|
||||
tenant["embd_id"]))
|
||||
|
||||
|
||||
factory_infos = [{
|
||||
"name": "OpenAI",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"name": "OpenAI",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
}, {
|
||||
"name": "Tongyi-Qianwen",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
}, {
|
||||
"name": "ZHIPU-AI",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},
|
||||
{
|
||||
"name": "Local",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "Tongyi-Qianwen",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "ZHIPU-AI",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},
|
||||
{
|
||||
"name": "Local",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},{
|
||||
}, {
|
||||
"name": "Moonshot",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
}
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
}
|
||||
# {
|
||||
# "name": "文心一言",
|
||||
# "logo": "",
|
||||
@ -107,6 +116,8 @@ factory_infos = [{
|
||||
# "status": "1",
|
||||
# },
|
||||
]
|
||||
|
||||
|
||||
def init_llm_factory():
|
||||
llm_infos = [
|
||||
# ---------------------- OpenAI ------------------------
|
||||
@ -116,37 +127,37 @@ def init_llm_factory():
|
||||
"tags": "LLM,CHAT,4K",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-3.5-turbo-16k-0613",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16385,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "text-embedding-ada-002",
|
||||
"tags": "TEXT EMBEDDING,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "whisper-1",
|
||||
"tags": "SPEECH2TEXT",
|
||||
"max_tokens": 25*1024*1024,
|
||||
"max_tokens": 25 * 1024 * 1024,
|
||||
"model_type": LLMType.SPEECH2TEXT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-4",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-4-32k",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-4-vision-preview",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||
@ -160,31 +171,31 @@ def init_llm_factory():
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "qwen-plus",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "qwen-max-1201",
|
||||
"tags": "LLM,CHAT,6K",
|
||||
"max_tokens": 5899,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "text-embedding-v2",
|
||||
"tags": "TEXT EMBEDDING,2K",
|
||||
"max_tokens": 2048,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "paraformer-realtime-8k-v1",
|
||||
"tags": "SPEECH2TEXT",
|
||||
"max_tokens": 25*1024*1024,
|
||||
"max_tokens": 25 * 1024 * 1024,
|
||||
"model_type": LLMType.SPEECH2TEXT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "qwen-vl-max",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||
@ -245,13 +256,13 @@ def init_llm_factory():
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 128 * 1000,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "moonshot-v1-32k",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
}, {
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "moonshot-v1-128k",
|
||||
"tags": "LLM,CHAT",
|
||||
@ -294,7 +305,6 @@ def init_web_data():
|
||||
print("init web data success:{}".format(time.time() - start_time))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
init_web_db()
|
||||
init_web_data()
|
||||
init_web_data()
|
||||
|
||||
@ -18,4 +18,4 @@ import operator
|
||||
import time
|
||||
import typing
|
||||
from api.utils.log_utils import sql_logger
|
||||
import peewee
|
||||
import peewee
|
||||
|
||||
@ -18,10 +18,11 @@ class ReloadConfigBase:
|
||||
def get_all(cls):
|
||||
configs = {}
|
||||
for k, v in cls.__dict__.items():
|
||||
if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"):
|
||||
if not callable(getattr(cls, k)) and not k.startswith(
|
||||
"__") and not k.startswith("_"):
|
||||
configs[k] = v
|
||||
return configs
|
||||
|
||||
@classmethod
|
||||
def get(cls, config_name):
|
||||
return getattr(cls, config_name) if hasattr(cls, config_name) else None
|
||||
return getattr(cls, config_name) if hasattr(cls, config_name) else None
|
||||
|
||||
@ -51,4 +51,4 @@ class RuntimeConfig(ReloadConfigBase):
|
||||
|
||||
@classmethod
|
||||
def set_service_db(cls, service_db):
|
||||
cls.SERVICE_DB = service_db
|
||||
cls.SERVICE_DB = service_db
|
||||
|
||||
@ -27,7 +27,8 @@ class CommonService:
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
||||
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
|
||||
return cls.model.query(cols=cols, reverse=reverse,
|
||||
order_by=order_by, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -40,9 +41,11 @@ class CommonService:
|
||||
if not order_by or not hasattr(cls, order_by):
|
||||
order_by = "create_time"
|
||||
if reverse is True:
|
||||
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
|
||||
query_records = query_records.order_by(
|
||||
cls.model.getter_by(order_by).desc())
|
||||
elif reverse is False:
|
||||
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
|
||||
query_records = query_records.order_by(
|
||||
cls.model.getter_by(order_by).asc())
|
||||
return query_records
|
||||
|
||||
@classmethod
|
||||
@ -61,7 +64,7 @@ class CommonService:
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
#if "id" not in kwargs:
|
||||
# if "id" not in kwargs:
|
||||
# kwargs["id"] = get_uuid()
|
||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return sample_obj
|
||||
@ -95,7 +98,8 @@ class CommonService:
|
||||
for data in data_list:
|
||||
data["update_time"] = current_timestamp()
|
||||
data["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(data).where(cls.model.id == data["id"]).execute()
|
||||
cls.model.update(data).where(
|
||||
cls.model.id == data["id"]).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -128,7 +132,6 @@ class CommonService:
|
||||
def delete_by_id(cls, pid):
|
||||
return cls.model.delete().where(cls.model.id == pid).execute()
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_delete(cls, filters):
|
||||
@ -151,19 +154,30 @@ class CommonService:
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
|
||||
def filter_scope_list(cls, in_key, in_filters_list,
|
||||
filters=None, cols=None):
|
||||
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
|
||||
if not filters:
|
||||
filters = []
|
||||
res_list = []
|
||||
if cols:
|
||||
for i in in_filters_tuple_list:
|
||||
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
|
||||
query_records = cls.model.select(
|
||||
*
|
||||
cols).where(
|
||||
getattr(
|
||||
cls.model,
|
||||
in_key).in_(i),
|
||||
*
|
||||
filters)
|
||||
if query_records:
|
||||
res_list.extend([query_record for query_record in query_records])
|
||||
res_list.extend(
|
||||
[query_record for query_record in query_records])
|
||||
else:
|
||||
for i in in_filters_tuple_list:
|
||||
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
|
||||
query_records = cls.model.select().where(
|
||||
getattr(cls.model, in_key).in_(i), *filters)
|
||||
if query_records:
|
||||
res_list.extend([query_record for query_record in query_records])
|
||||
return res_list
|
||||
res_list.extend(
|
||||
[query_record for query_record in query_records])
|
||||
return res_list
|
||||
|
||||
@ -21,6 +21,5 @@ class DialogService(CommonService):
|
||||
model = Dialog
|
||||
|
||||
|
||||
|
||||
class ConversationService(CommonService):
|
||||
model = Conversation
|
||||
|
||||
@ -72,7 +72,20 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.kb_id,
|
||||
cls.model.parser_id,
|
||||
cls.model.parser_config,
|
||||
cls.model.name,
|
||||
cls.model.type,
|
||||
cls.model.location,
|
||||
cls.model.size,
|
||||
Knowledgebase.tenant_id,
|
||||
Tenant.embd_id,
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
|
||||
@ -103,40 +116,64 @@ class DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
||||
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
||||
chunk_num=cls.model.chunk_num + chunk_num,
|
||||
process_duation=cls.model.process_duation+duation).where(
|
||||
chunk_num=cls.model.chunk_num + chunk_num,
|
||||
process_duation=cls.model.process_duation + duation).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
if num == 0:raise LookupError("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
|
||||
if num == 0:
|
||||
raise LookupError(
|
||||
"Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num +
|
||||
token_num,
|
||||
chunk_num=Knowledgebase.chunk_num +
|
||||
chunk_num).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id(cls, doc_id):
|
||||
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status==StatusEnum.VALID.value)
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:return
|
||||
if not docs:
|
||||
return
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_thumbnails(cls, docids):
|
||||
fields = [cls.model.id, cls.model.thumbnail]
|
||||
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
return list(cls.model.select(
|
||||
*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_parser_config(cls, id, config):
|
||||
e, d = cls.get_by_id(id)
|
||||
if not e:raise LookupError(f"Document({id}) not found.")
|
||||
if not e:
|
||||
raise LookupError(f"Document({id}) not found.")
|
||||
|
||||
def dfs_update(old, new):
|
||||
for k,v in new.items():
|
||||
for k, v in new.items():
|
||||
if k not in old:
|
||||
old[k] = v
|
||||
continue
|
||||
if isinstance(v, dict):
|
||||
assert isinstance(old[k], dict)
|
||||
dfs_update(old[k], v)
|
||||
else: old[k] = v
|
||||
else:
|
||||
old[k] = v
|
||||
dfs_update(d.parser_config, config)
|
||||
cls.update_by_id(id, {"parser_config": d.parser_config})
|
||||
cls.update_by_id(id, {"parser_config": d.parser_config})
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_count(cls, tenant_id):
|
||||
docs = cls.model.select(cls.model.id).join(Knowledgebase,
|
||||
on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
Knowledgebase.tenant_id == tenant_id)
|
||||
return len(docs)
|
||||
|
||||
@ -55,7 +55,7 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.parser_config]
|
||||
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
|
||||
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
|
||||
(cls.model.id == kb_id),
|
||||
(cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
@ -69,9 +69,11 @@ class KnowledgebaseService(CommonService):
|
||||
@DB.connection_context()
|
||||
def update_parser_config(cls, id, config):
|
||||
e, m = cls.get_by_id(id)
|
||||
if not e:raise LookupError(f"knowledgebase({id}) not found.")
|
||||
if not e:
|
||||
raise LookupError(f"knowledgebase({id}) not found.")
|
||||
|
||||
def dfs_update(old, new):
|
||||
for k,v in new.items():
|
||||
for k, v in new.items():
|
||||
if k not in old:
|
||||
old[k] = v
|
||||
continue
|
||||
@ -80,12 +82,12 @@ class KnowledgebaseService(CommonService):
|
||||
dfs_update(old[k], v)
|
||||
elif isinstance(v, list):
|
||||
assert isinstance(old[k], list)
|
||||
old[k] = list(set(old[k]+v))
|
||||
else: old[k] = v
|
||||
old[k] = list(set(old[k] + v))
|
||||
else:
|
||||
old[k] = v
|
||||
dfs_update(m.parser_config, config)
|
||||
cls.update_by_id(id, {"parser_config": m.parser_config})
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_field_map(cls, ids):
|
||||
@ -94,4 +96,3 @@ class KnowledgebaseService(CommonService):
|
||||
if k.parser_config and "field_map" in k.parser_config:
|
||||
conf.update(k.parser_config["field_map"])
|
||||
return conf
|
||||
|
||||
|
||||
@ -59,7 +59,8 @@ class TenantLLMService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
|
||||
def model_instance(cls, tenant_id, llm_type,
|
||||
llm_name=None, lang="Chinese"):
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
raise LookupError("Tenant not found")
|
||||
@ -126,29 +127,39 @@ class LLMBundle(object):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
self.llm_name = llm_name
|
||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
|
||||
assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||
self.mdl = TenantLLMService.model_instance(
|
||||
tenant_id, llm_type, llm_name, lang=lang)
|
||||
assert self.mdl, "Can't find mole for {}/{}/{}".format(
|
||||
tenant_id, llm_type, llm_name)
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
emd, used_tokens = self.mdl.encode(texts, batch_size)
|
||||
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
||||
if TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error(
|
||||
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
||||
return emd, used_tokens
|
||||
|
||||
def encode_queries(self, query: str):
|
||||
emd, used_tokens = self.mdl.encode_queries(query)
|
||||
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
||||
if TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error(
|
||||
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
||||
return emd, used_tokens
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
txt, used_tokens = self.mdl.describe(image, max_tokens)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error(
|
||||
"Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
|
||||
return txt
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
||||
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id))
|
||||
if TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
database_logger.error(
|
||||
"Can't update token usage for {}/CHAT".format(self.tenant_id))
|
||||
return txt
|
||||
|
||||
@ -54,7 +54,8 @@ class UserService(CommonService):
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = get_uuid()
|
||||
if "password" in kwargs:
|
||||
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
||||
kwargs["password"] = generate_password_hash(
|
||||
str(kwargs["password"]))
|
||||
|
||||
kwargs["create_time"] = current_timestamp()
|
||||
kwargs["create_date"] = datetime_format(datetime.now())
|
||||
@ -63,12 +64,12 @@ class UserService(CommonService):
|
||||
obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return obj
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_user(cls, user_ids, update_user_dict):
|
||||
with DB.atomic():
|
||||
cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute()
|
||||
cls.model.update({"status": 0}).where(
|
||||
cls.model.id.in_(user_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -77,7 +78,8 @@ class UserService(CommonService):
|
||||
if user_dict:
|
||||
user_dict["update_time"] = current_timestamp()
|
||||
user_dict["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
||||
cls.model.update(user_dict).where(
|
||||
cls.model.id == user_id).execute()
|
||||
|
||||
|
||||
class TenantService(CommonService):
|
||||
@ -86,25 +88,42 @@ class TenantService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_user_id(cls, user_id):
|
||||
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
|
||||
return list(cls.model.select(*fields)\
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
fields = [
|
||||
cls.model.id.alias("tenant_id"),
|
||||
cls.model.name,
|
||||
cls.model.llm_id,
|
||||
cls.model.embd_id,
|
||||
cls.model.asr_id,
|
||||
cls.model.img2txt_id,
|
||||
cls.model.parser_ids,
|
||||
UserTenant.role]
|
||||
return list(cls.model.select(*fields)
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_joined_tenants_by_user_id(cls, user_id):
|
||||
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
|
||||
return list(cls.model.select(*fields)\
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
fields = [
|
||||
cls.model.id.alias("tenant_id"),
|
||||
cls.model.name,
|
||||
cls.model.llm_id,
|
||||
cls.model.embd_id,
|
||||
cls.model.asr_id,
|
||||
cls.model.img2txt_id,
|
||||
UserTenant.role]
|
||||
return list(cls.model.select(*fields)
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL.value)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrease(cls, user_id, num):
|
||||
num = cls.model.update(credit=cls.model.credit - num).where(
|
||||
cls.model.id == user_id).execute()
|
||||
if num == 0: raise LookupError("Tenant not found which is supposed to be there")
|
||||
if num == 0:
|
||||
raise LookupError("Tenant not found which is supposed to be there")
|
||||
|
||||
|
||||
class UserTenantService(CommonService):
|
||||
model = UserTenant
|
||||
|
||||
@ -13,16 +13,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from rag.utils import ELASTICSEARCH
|
||||
from rag.nlp import search
|
||||
import os
|
||||
|
||||
from enum import IntEnum, Enum
|
||||
|
||||
from api.utils import get_base_config,decrypt_database_config
|
||||
from api.utils import get_base_config, decrypt_database_config
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from api.utils.log_utils import LoggerFactory, getLogger
|
||||
|
||||
# Logger
|
||||
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "api"))
|
||||
LoggerFactory.set_directory(
|
||||
os.path.join(
|
||||
get_project_base_directory(),
|
||||
"logs",
|
||||
"api"))
|
||||
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
||||
LoggerFactory.LEVEL = 10
|
||||
|
||||
@ -86,7 +92,9 @@ default_llm = {
|
||||
LLM = get_base_config("user_default_llm", {})
|
||||
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
|
||||
if LLM_FACTORY not in default_llm:
|
||||
print("\33[91m【ERROR】\33[0m:", f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
|
||||
print(
|
||||
"\33[91m【ERROR】\33[0m:",
|
||||
f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
|
||||
LLM_FACTORY = "Tongyi-Qianwen"
|
||||
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
|
||||
EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
|
||||
@ -94,7 +102,9 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
||||
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
||||
|
||||
API_KEY = LLM.get("api_key", "")
|
||||
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One")
|
||||
PARSERS = LLM.get(
|
||||
"parsers",
|
||||
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One")
|
||||
|
||||
# distribution
|
||||
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
||||
@ -103,13 +113,25 @@ RAG_FLOW_UPDATE_CHECK = False
|
||||
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
||||
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
||||
|
||||
SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow")
|
||||
TOKEN_EXPIRE_IN = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600)
|
||||
SECRET_KEY = get_base_config(
|
||||
RAG_FLOW_SERVICE_NAME,
|
||||
{}).get(
|
||||
"secret_key",
|
||||
"infiniflow")
|
||||
TOKEN_EXPIRE_IN = get_base_config(
|
||||
RAG_FLOW_SERVICE_NAME, {}).get(
|
||||
"token_expires_in", 3600)
|
||||
|
||||
NGINX_HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST
|
||||
NGINX_HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT
|
||||
NGINX_HOST = get_base_config(
|
||||
RAG_FLOW_SERVICE_NAME, {}).get(
|
||||
"nginx", {}).get("host") or HOST
|
||||
NGINX_HTTP_PORT = get_base_config(
|
||||
RAG_FLOW_SERVICE_NAME, {}).get(
|
||||
"nginx", {}).get("http_port") or HTTP_PORT
|
||||
|
||||
RANDOM_INSTANCE_ID = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("random_instance_id", False)
|
||||
RANDOM_INSTANCE_ID = get_base_config(
|
||||
RAG_FLOW_SERVICE_NAME, {}).get(
|
||||
"random_instance_id", False)
|
||||
|
||||
PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
|
||||
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
|
||||
@ -124,7 +146,9 @@ UPLOAD_DATA_FROM_CLIENT = True
|
||||
AUTHENTICATION_CONF = get_base_config("authentication", {})
|
||||
|
||||
# client
|
||||
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
|
||||
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
|
||||
"client", {}).get(
|
||||
"switch", False)
|
||||
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
|
||||
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
||||
WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
|
||||
@ -147,12 +171,10 @@ USE_AUTHENTICATION = False
|
||||
USE_DATA_AUTHENTICATION = False
|
||||
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
|
||||
USE_DEFAULT_TIMEOUT = False
|
||||
AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
|
||||
AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
|
||||
PRIVILEGE_COMMAND_WHITELIST = []
|
||||
CHECK_NODES_IDENTITY = False
|
||||
|
||||
from rag.nlp import search
|
||||
from rag.utils import ELASTICSEARCH
|
||||
retrievaler = search.Dealer(ELASTICSEARCH)
|
||||
|
||||
|
||||
@ -162,7 +184,7 @@ class CustomEnum(Enum):
|
||||
try:
|
||||
cls(value)
|
||||
return True
|
||||
except:
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -34,10 +34,12 @@ from . import file_utils
|
||||
|
||||
SERVICE_CONF = "service_conf.yaml"
|
||||
|
||||
|
||||
def conf_realpath(conf_name):
|
||||
conf_path = f"conf/{conf_name}"
|
||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
|
||||
def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
|
||||
local_config = {}
|
||||
local_path = conf_realpath(f'local.{conf_name}')
|
||||
@ -62,7 +64,8 @@ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
|
||||
return config.get(key, default) if key is not None else config
|
||||
|
||||
|
||||
use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False)
|
||||
use_deserialize_safe_module = get_base_config(
|
||||
'use_deserialize_safe_module', False)
|
||||
|
||||
|
||||
class CoordinationCommunicationProtocol(object):
|
||||
@ -93,7 +96,8 @@ class BaseType:
|
||||
data[_k] = _dict(vv)
|
||||
else:
|
||||
data = obj
|
||||
return {"type": obj.__class__.__name__, "data": data, "module": module}
|
||||
return {"type": obj.__class__.__name__,
|
||||
"data": data, "module": module}
|
||||
return _dict(self)
|
||||
|
||||
|
||||
@ -129,7 +133,8 @@ def rag_uuid():
|
||||
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(string, bytes) else string.encode(encoding="utf-8")
|
||||
return string if isinstance(
|
||||
string, bytes) else string.encode(encoding="utf-8")
|
||||
|
||||
|
||||
def bytes_to_string(byte):
|
||||
@ -137,7 +142,11 @@ def bytes_to_string(byte):
|
||||
|
||||
|
||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
|
||||
dest = json.dumps(
|
||||
src,
|
||||
indent=indent,
|
||||
cls=CustomJSONEncoder,
|
||||
with_type=with_type)
|
||||
if byte:
|
||||
dest = string_to_bytes(dest)
|
||||
return dest
|
||||
@ -146,7 +155,8 @@ def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||
if isinstance(src, bytes):
|
||||
src = bytes_to_string(src)
|
||||
return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook)
|
||||
return json.loads(src, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook)
|
||||
|
||||
|
||||
def current_timestamp():
|
||||
@ -177,7 +187,9 @@ def serialize_b64(src, to_str=False):
|
||||
|
||||
|
||||
def deserialize_b64(src):
|
||||
src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src)
|
||||
src = base64.b64decode(
|
||||
string_to_bytes(src) if isinstance(
|
||||
src, str) else src)
|
||||
if use_deserialize_safe_module:
|
||||
return restricted_loads(src)
|
||||
return pickle.loads(src)
|
||||
@ -237,12 +249,14 @@ def get_lan_ip():
|
||||
pass
|
||||
return ip or ''
|
||||
|
||||
|
||||
def from_dict_hook(in_dict: dict):
|
||||
if "type" in in_dict and "data" in in_dict:
|
||||
if in_dict["module"] is None:
|
||||
return in_dict["data"]
|
||||
else:
|
||||
return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"])
|
||||
return getattr(importlib.import_module(
|
||||
in_dict["module"]), in_dict["type"])(**in_dict["data"])
|
||||
else:
|
||||
return in_dict
|
||||
|
||||
@ -259,12 +273,16 @@ def decrypt_database_password(password):
|
||||
raise ValueError("No private key")
|
||||
|
||||
module_fun = encrypt_module.split("#")
|
||||
pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1])
|
||||
pwdecrypt_fun = getattr(
|
||||
importlib.import_module(
|
||||
module_fun[0]),
|
||||
module_fun[1])
|
||||
|
||||
return pwdecrypt_fun(private_key, password)
|
||||
|
||||
|
||||
def decrypt_database_config(database=None, passwd_key="password", name="database"):
|
||||
def decrypt_database_config(
|
||||
database=None, passwd_key="password", name="database"):
|
||||
if not database:
|
||||
database = get_base_config(name, {})
|
||||
|
||||
@ -275,7 +293,8 @@ def decrypt_database_config(database=None, passwd_key="password", name="database
|
||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||
conf_path = conf_realpath(conf_name=conf_name)
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
conf_path = os.path.join(
|
||||
file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||
@ -288,7 +307,8 @@ def get_uuid():
|
||||
|
||||
|
||||
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
|
||||
return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second)
|
||||
return datetime.datetime(date_time.year, date_time.month, date_time.day,
|
||||
date_time.hour, date_time.minute, date_time.second)
|
||||
|
||||
|
||||
def get_format_time() -> datetime.datetime:
|
||||
@ -307,14 +327,19 @@ def elapsed2time(elapsed):
|
||||
|
||||
|
||||
def decrypt(line):
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||
file_path = os.path.join(
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"private.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
|
||||
return cipher.decrypt(base64.b64decode(
|
||||
line), "Fail to decrypt password!").decode('utf-8')
|
||||
|
||||
|
||||
def download_img(url):
|
||||
if not url: return ""
|
||||
if not url:
|
||||
return ""
|
||||
response = requests.get(url)
|
||||
return "data:" + \
|
||||
response.headers.get('Content-Type', 'image/jpg') + ";" + \
|
||||
|
||||
@ -19,7 +19,7 @@ import time
|
||||
from functools import wraps
|
||||
from io import BytesIO
|
||||
from flask import (
|
||||
Response, jsonify, send_file,make_response,
|
||||
Response, jsonify, send_file, make_response,
|
||||
request as flask_request,
|
||||
)
|
||||
from werkzeug.http import HTTP_STATUS_CODES
|
||||
@ -29,7 +29,7 @@ from api.versions import get_rag_version
|
||||
from api.settings import RetCode
|
||||
from api.settings import (
|
||||
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
|
||||
stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
|
||||
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
|
||||
)
|
||||
import requests
|
||||
import functools
|
||||
@ -40,14 +40,21 @@ from hmac import HMAC
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
requests.models.complexjson.dumps = functools.partial(
|
||||
json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
|
||||
def request(**kwargs):
|
||||
sess = requests.Session()
|
||||
stream = kwargs.pop('stream', sess.stream)
|
||||
timeout = kwargs.pop('timeout', None)
|
||||
kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()}
|
||||
kwargs['headers'] = {
|
||||
k.replace(
|
||||
'_',
|
||||
'-').upper(): v for k,
|
||||
v in kwargs.get(
|
||||
'headers',
|
||||
{}).items()}
|
||||
prepped = requests.Request(**kwargs).prepare()
|
||||
|
||||
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
|
||||
@ -59,7 +66,11 @@ def request(**kwargs):
|
||||
HTTP_APP_KEY.encode('ascii'),
|
||||
prepped.path_url.encode('ascii'),
|
||||
prepped.body if kwargs.get('json') else b'',
|
||||
urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii')
|
||||
urlencode(
|
||||
sorted(
|
||||
kwargs['data'].items()),
|
||||
quote_via=quote,
|
||||
safe='-._~').encode('ascii')
|
||||
if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
|
||||
]), 'sha1').digest()).decode('ascii')
|
||||
|
||||
@ -88,11 +99,12 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
|
||||
return max(0, countdown)
|
||||
|
||||
|
||||
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None):
|
||||
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
|
||||
data=None, job_id=None, meta=None):
|
||||
import re
|
||||
result_dict = {
|
||||
"retcode": retcode,
|
||||
"retmsg":retmsg,
|
||||
"retmsg": retmsg,
|
||||
# "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
|
||||
"data": data,
|
||||
"jobId": job_id,
|
||||
@ -107,9 +119,17 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id
|
||||
response[key] = value
|
||||
return jsonify(response)
|
||||
|
||||
def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'):
|
||||
|
||||
def get_data_error_result(retcode=RetCode.DATA_ERROR,
|
||||
retmsg='Sorry! Data missing!'):
|
||||
import re
|
||||
result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)}
|
||||
result_dict = {
|
||||
"retcode": retcode,
|
||||
"retmsg": re.sub(
|
||||
r"rag",
|
||||
"seceum",
|
||||
retmsg,
|
||||
flags=re.IGNORECASE)}
|
||||
response = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is None and key != "retcode":
|
||||
@ -118,15 +138,17 @@ def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missin
|
||||
response[key] = value
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
def server_error_response(e):
|
||||
stat_logger.exception(e)
|
||||
try:
|
||||
if e.code==401:
|
||||
if e.code == 401:
|
||||
return get_json_result(retcode=401, retmsg=repr(e))
|
||||
except:
|
||||
except BaseException:
|
||||
pass
|
||||
if len(e.args) > 1:
|
||||
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
|
||||
return get_json_result(
|
||||
retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
|
||||
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
|
||||
|
||||
|
||||
@ -162,10 +184,13 @@ def validate_request(*args, **kwargs):
|
||||
if no_arguments or error_arguments:
|
||||
error_string = ""
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
error_string += "required argument are missing: {}; ".format(
|
||||
",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
|
||||
error_string += "required argument values: {}".format(
|
||||
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return get_json_result(
|
||||
retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
|
||||
return func(*_args, **_kwargs)
|
||||
return decorated_function
|
||||
return wrapper
|
||||
@ -193,7 +218,8 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None):
|
||||
def cors_reponse(retcode=RetCode.SUCCESS,
|
||||
retmsg='success', data=None, auth=None):
|
||||
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
|
||||
response_dict = {}
|
||||
for key, value in result_dict.items():
|
||||
@ -209,4 +235,4 @@ def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Expose-Headers"] = "Authorization"
|
||||
return response
|
||||
return response
|
||||
|
||||
@ -29,6 +29,7 @@ from api.db import FileType
|
||||
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
||||
RAG_BASE = os.getenv("RAG_BASE")
|
||||
|
||||
|
||||
def get_project_base_directory(*args):
|
||||
global PROJECT_BASE
|
||||
if PROJECT_BASE is None:
|
||||
@ -65,7 +66,6 @@ def get_rag_python_directory(*args):
|
||||
return get_rag_directory("python", *args)
|
||||
|
||||
|
||||
|
||||
@cached(cache=LRUCache(maxsize=10))
|
||||
def load_json_conf(conf_path):
|
||||
if os.path.isabs(conf_path):
|
||||
@ -146,10 +146,12 @@ def filename_type(filename):
|
||||
if re.match(r".*\.pdf$", filename):
|
||||
return FileType.PDF.value
|
||||
|
||||
if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename):
|
||||
if re.match(
|
||||
r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename):
|
||||
return FileType.DOC.value
|
||||
|
||||
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|
||||
if re.match(
|
||||
r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|
||||
return FileType.AURAL.value
|
||||
|
||||
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
|
||||
@ -164,14 +166,16 @@ def thumbnail(filename, blob):
|
||||
buffered = BytesIO()
|
||||
Image.frombytes("RGB", [pix.width, pix.height],
|
||||
pix.samples).save(buffered, format="png")
|
||||
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
return "data:image/png;base64," + \
|
||||
base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
|
||||
image = Image.open(BytesIO(blob))
|
||||
image.thumbnail((30, 30))
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="png")
|
||||
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
return "data:image/png;base64," + \
|
||||
base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
if re.match(r".*\.(ppt|pptx)$", filename):
|
||||
import aspose.slides as slides
|
||||
@ -179,8 +183,10 @@ def thumbnail(filename, blob):
|
||||
try:
|
||||
with slides.Presentation(BytesIO(blob)) as presentation:
|
||||
buffered = BytesIO()
|
||||
presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
|
||||
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
presentation.slides[0].get_thumbnail(0.03, 0.03).save(
|
||||
buffered, drawing.imaging.ImageFormat.png)
|
||||
return "data:image/png;base64," + \
|
||||
base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@ -190,6 +196,3 @@ def traversal_files(base):
|
||||
for f in fs:
|
||||
fullname = os.path.join(root, f)
|
||||
yield fullname
|
||||
|
||||
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ from threading import RLock
|
||||
|
||||
from api.utils import file_utils
|
||||
|
||||
|
||||
class LoggerFactory(object):
|
||||
TYPE = "FILE"
|
||||
LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
|
||||
@ -49,7 +50,8 @@ class LoggerFactory(object):
|
||||
schedule_logger_dict = {}
|
||||
|
||||
@staticmethod
|
||||
def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False):
|
||||
def set_directory(directory=None, parent_log_dir=None,
|
||||
append_to_parent_log=None, force=False):
|
||||
if parent_log_dir:
|
||||
LoggerFactory.PARENT_LOG_DIR = parent_log_dir
|
||||
if append_to_parent_log:
|
||||
@ -66,11 +68,13 @@ class LoggerFactory(object):
|
||||
else:
|
||||
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
|
||||
for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
|
||||
for className, (logger, handler) in LoggerFactory.logger_dict.items():
|
||||
for className, (logger,
|
||||
handler) in LoggerFactory.logger_dict.items():
|
||||
logger.removeHandler(ghandler)
|
||||
ghandler.close()
|
||||
LoggerFactory.global_handler_dict = {}
|
||||
for className, (logger, handler) in LoggerFactory.logger_dict.items():
|
||||
for className, (logger,
|
||||
handler) in LoggerFactory.logger_dict.items():
|
||||
logger.removeHandler(handler)
|
||||
_handler = None
|
||||
if handler:
|
||||
@ -111,19 +115,23 @@ class LoggerFactory(object):
|
||||
if logger_name_key not in LoggerFactory.global_handler_dict:
|
||||
with LoggerFactory.lock:
|
||||
if logger_name_key not in LoggerFactory.global_handler_dict:
|
||||
handler = LoggerFactory.get_handler(logger_name, level, log_dir)
|
||||
handler = LoggerFactory.get_handler(
|
||||
logger_name, level, log_dir)
|
||||
LoggerFactory.global_handler_dict[logger_name_key] = handler
|
||||
return LoggerFactory.global_handler_dict[logger_name_key]
|
||||
|
||||
@staticmethod
|
||||
def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None):
|
||||
def get_handler(class_name, level=None, log_dir=None,
|
||||
log_type=None, job_id=None):
|
||||
if not log_type:
|
||||
if not LoggerFactory.LOG_DIR or not class_name:
|
||||
return logging.StreamHandler()
|
||||
# return Diy_StreamHandler()
|
||||
|
||||
if not log_dir:
|
||||
log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name))
|
||||
log_file = os.path.join(
|
||||
LoggerFactory.LOG_DIR,
|
||||
"{}.log".format(class_name))
|
||||
else:
|
||||
log_file = os.path.join(log_dir, "{}.log".format(class_name))
|
||||
else:
|
||||
@ -133,16 +141,16 @@ class LoggerFactory(object):
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
if LoggerFactory.log_share:
|
||||
handler = ROpenHandler(log_file,
|
||||
when='D',
|
||||
interval=1,
|
||||
backupCount=14,
|
||||
delay=True)
|
||||
when='D',
|
||||
interval=1,
|
||||
backupCount=14,
|
||||
delay=True)
|
||||
else:
|
||||
handler = TimedRotatingFileHandler(log_file,
|
||||
when='D',
|
||||
interval=1,
|
||||
backupCount=14,
|
||||
delay=True)
|
||||
when='D',
|
||||
interval=1,
|
||||
backupCount=14,
|
||||
delay=True)
|
||||
if level:
|
||||
handler.level = level
|
||||
|
||||
@ -170,7 +178,9 @@ class LoggerFactory(object):
|
||||
for level in LoggerFactory.levels:
|
||||
if level >= LoggerFactory.LEVEL:
|
||||
level_logger_name = logging._levelToName[level]
|
||||
logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level))
|
||||
logger.addHandler(
|
||||
LoggerFactory.get_global_handler(
|
||||
level_logger_name, level))
|
||||
if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
|
||||
for level in LoggerFactory.levels:
|
||||
if level >= LoggerFactory.LEVEL:
|
||||
@ -224,22 +234,26 @@ def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
|
||||
return f"{prefix}start to {msg}{suffix}"
|
||||
|
||||
|
||||
def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
|
||||
def successful_log(msg, job=None, task=None, role=None,
|
||||
party_id=None, detail=None):
|
||||
prefix, suffix = base_msg(job, task, role, party_id, detail)
|
||||
return f"{prefix}{msg} successfully{suffix}"
|
||||
|
||||
|
||||
def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
|
||||
def warning_log(msg, job=None, task=None, role=None,
|
||||
party_id=None, detail=None):
|
||||
prefix, suffix = base_msg(job, task, role, party_id, detail)
|
||||
return f"{prefix}{msg} is not effective{suffix}"
|
||||
|
||||
|
||||
def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
|
||||
def failed_log(msg, job=None, task=None, role=None,
|
||||
party_id=None, detail=None):
|
||||
prefix, suffix = base_msg(job, task, role, party_id, detail)
|
||||
return f"{prefix}failed to {msg}{suffix}"
|
||||
|
||||
|
||||
def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None):
|
||||
def base_msg(job=None, task=None, role: str = None,
|
||||
party_id: typing.Union[str, int] = None, detail=None):
|
||||
if detail:
|
||||
detail_msg = f" detail: \n{detail}"
|
||||
else:
|
||||
@ -285,10 +299,14 @@ def get_job_logger(job_id, log_type):
|
||||
for job_log_dir in log_dirs:
|
||||
handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
|
||||
log_dir=job_log_dir, log_type=log_type, job_id=job_id)
|
||||
error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id)
|
||||
error_handler = LoggerFactory.get_handler(
|
||||
class_name=None,
|
||||
level=logging.ERROR,
|
||||
log_dir=job_log_dir,
|
||||
log_type=log_type,
|
||||
job_id=job_id)
|
||||
logger.addHandler(handler)
|
||||
logger.addHandler(error_handler)
|
||||
with LoggerFactory.lock:
|
||||
LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
|
||||
return logger
|
||||
|
||||
|
||||
@ -1,18 +1,23 @@
|
||||
import base64, os, sys
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from api.utils import decrypt, file_utils
|
||||
|
||||
|
||||
def crypt(line):
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
||||
file_path = os.path.join(
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read())
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8")
|
||||
|
||||
return base64.b64encode(cipher.encrypt(
|
||||
line.encode('utf-8'))).decode("utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pswd = crypt(sys.argv[1])
|
||||
print(pswd)
|
||||
print(decrypt(pswd))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user