mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
apply pep8 formalize (#155)
This commit is contained in:
@ -121,7 +121,9 @@ def get():
|
||||
"important_kwd")
|
||||
def set():
|
||||
req = request.json
|
||||
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = huqie.qie(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
@ -140,10 +142,16 @@ def set():
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
|
||||
if len(arr) != 2: return get_data_error_result(retmsg="Q&A must be separated by TAB/ENTER key.")
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(
|
||||
retmsg="Q&A must be separated by TAB/ENTER key.")
|
||||
q, a = rmPrefix(arr[0]), rmPrefix[arr[1]]
|
||||
d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q + a]))
|
||||
d = beAdoc(d, arr[0], arr[1], not any(
|
||||
[huqie.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
@ -177,7 +185,8 @@ def switch():
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
if not ELASTICSEARCH.deleteByQuery(Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
|
||||
if not ELASTICSEARCH.deleteByQuery(
|
||||
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
|
||||
return get_data_error_result(retmsg="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
|
||||
@ -100,7 +100,10 @@ def rm():
|
||||
def list_convsersation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
|
||||
convs = ConversationService.query(
|
||||
dialog_id=dialog_id,
|
||||
order_by=ConversationService.model.create_time,
|
||||
reverse=True)
|
||||
convs = [d.to_dict() for d in convs]
|
||||
return get_json_result(data=convs)
|
||||
except Exception as e:
|
||||
@ -111,19 +114,24 @@ def message_fit_in(msg, max_length=4000):
|
||||
def count():
|
||||
nonlocal msg
|
||||
tks_cnts = []
|
||||
for m in msg: tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
for m in msg:
|
||||
tks_cnts.append(
|
||||
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
total = 0
|
||||
for m in tks_cnts: total += m["count"]
|
||||
for m in tks_cnts:
|
||||
total += m["count"]
|
||||
return total
|
||||
|
||||
c = count()
|
||||
if c < max_length: return c, msg
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
msg_ = [m for m in msg[:-1] if m.role == "system"]
|
||||
msg_.append(msg[-1])
|
||||
msg = msg_
|
||||
c = count()
|
||||
if c < max_length: return c, msg
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
ll = num_tokens_from_string(msg_[0].content)
|
||||
l = num_tokens_from_string(msg_[-1].content)
|
||||
@ -146,8 +154,10 @@ def completion():
|
||||
req = request.json
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system": continue
|
||||
if m["role"] == "assistant" and not msg: continue
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append({"role": m["role"], "content": m["content"]})
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
@ -160,7 +170,8 @@ def completion():
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
ans = chat(dia, msg, **req)
|
||||
if not conv.reference: conv.reference = []
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.reference.append(ans["reference"])
|
||||
conv.message.append({"role": "assistant", "content": ans["answer"]})
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
@ -180,52 +191,67 @@ def chat(dialog, messages, **kwargs):
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
## try to use sql if field mapping is good to go
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge": continue
|
||||
if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"])
|
||||
if p["key"] == "knowledge":
|
||||
continue
|
||||
if p["key"] not in kwargs and not p["optional"]:
|
||||
raise KeyError("Miss parameter: " + p["key"])
|
||||
if p["key"] not in kwargs:
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||
prompt_config["system"] = prompt_config["system"].replace(
|
||||
"{%s}" % p["key"], " ")
|
||||
|
||||
for _ in range(len(questions)//2):
|
||||
for _ in range(len(questions) // 2):
|
||||
questions.append(questions[-1])
|
||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
||||
kbinfos = {"total":0, "chunks":[],"doc_aggs":[]}
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
else:
|
||||
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight, top=1024, aggs=False)
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight, top=1024, aggs=False)
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
chat_logger.info(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
return {
|
||||
"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
|
||||
kwargs["knowledge"] = "\n".join(knowledges)
|
||||
gen_conf = dialog.llm_setting
|
||||
msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
|
||||
msg = [{"role": m["role"], "content": m["content"]}
|
||||
for m in messages if m["role"] != "system"]
|
||||
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
|
||||
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
|
||||
chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
|
||||
gen_conf["max_tokens"] = min(
|
||||
gen_conf["max_tokens"],
|
||||
llm.max_tokens - used_token_count)
|
||||
answer = chat_mdl.chat(
|
||||
prompt_config["system"].format(
|
||||
**kwargs), msg, gen_conf)
|
||||
chat_logger.info("User: {}|Assistant: {}".format(
|
||||
msg[-1]["content"], answer))
|
||||
|
||||
if knowledges:
|
||||
answer, idx = retrievaler.insert_citations(answer,
|
||||
[ck["content_ltks"] for ck in kbinfos["chunks"]],
|
||||
[ck["vector"] for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
kbinfos["doc_aggs"] = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
kbinfos["doc_aggs"] = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
for c in kbinfos["chunks"]:
|
||||
if c.get("vector"): del c["vector"]
|
||||
if c.get("vector"):
|
||||
del c["vector"]
|
||||
return {"answer": answer, "reference": kbinfos}
|
||||
|
||||
|
||||
@ -245,9 +271,11 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
question
|
||||
)
|
||||
tried_times = 0
|
||||
|
||||
def get_table():
|
||||
nonlocal sys_prompt, user_promt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
|
||||
"temperature": 0.06})
|
||||
print(user_promt, sql)
|
||||
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
@ -262,8 +290,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
else:
|
||||
flds = []
|
||||
for k in field_map.keys():
|
||||
if k in forbidden_select_fields4resume:continue
|
||||
if len(flds) > 11:break
|
||||
if k in forbidden_select_fields4resume:
|
||||
continue
|
||||
if len(flds) > 11:
|
||||
break
|
||||
flds.append(k)
|
||||
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
||||
|
||||
@ -284,13 +314,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
|
||||
问题如下:
|
||||
{}
|
||||
|
||||
|
||||
你上一次给出的错误SQL如下:
|
||||
{}
|
||||
|
||||
|
||||
后台报错如下:
|
||||
{}
|
||||
|
||||
|
||||
请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
@ -302,16 +332,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
|
||||
chat_logger.info("GET table: {}".format(tbl))
|
||||
print(tbl)
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0:
|
||||
return None, None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
|
||||
docid_idx = set([ii for ii, c in enumerate(
|
||||
tbl["columns"]) if c["name"] == "doc_id"])
|
||||
docnm_idx = set([ii for ii, c in enumerate(
|
||||
tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(
|
||||
len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
|
||||
|
||||
# compose markdown table
|
||||
clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
||||
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
||||
("|------|" if docid_idx and docid_idx else "")
|
||||
rows = ["|" +
|
||||
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
||||
"|" for r in tbl["rows"]]
|
||||
if not docid_idx or not docnm_idx:
|
||||
chat_logger.warning("SQL missing field: " + sql)
|
||||
return "\n".join([clmns, line, "\n".join(rows)]), []
|
||||
@ -328,5 +366,5 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
return {
|
||||
"answer": "\n".join([clmns, line, rows]),
|
||||
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
|
||||
"doc_aggs":[{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
|
||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
|
||||
}
|
||||
|
||||
@ -55,7 +55,8 @@ def set_dialog():
|
||||
}
|
||||
prompt_config = req.get("prompt_config", default_prompt)
|
||||
|
||||
if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"]
|
||||
if not prompt_config["system"]:
|
||||
prompt_config["system"] = default_prompt["system"]
|
||||
# if len(prompt_config["parameters"]) < 1:
|
||||
# prompt_config["parameters"] = default_prompt["parameters"]
|
||||
# for p in prompt_config["parameters"]:
|
||||
@ -63,16 +64,21 @@ def set_dialog():
|
||||
# else: prompt_config["parameters"].append(default_prompt["parameters"][0])
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["optional"]: continue
|
||||
if p["optional"]:
|
||||
continue
|
||||
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
|
||||
return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
|
||||
return get_data_error_result(
|
||||
retmsg="Parameter '{}' is not used".format(p["key"]))
|
||||
|
||||
try:
|
||||
e, tenant = TenantService.get_by_id(current_user.id)
|
||||
if not e: return get_data_error_result(retmsg="Tenant not found!")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
llm_id = req.get("llm_id", tenant.llm_id)
|
||||
if not dialog_id:
|
||||
if not req.get("kb_ids"):return get_data_error_result(retmsg="Fail! Please select knowledgebase!")
|
||||
if not req.get("kb_ids"):
|
||||
return get_data_error_result(
|
||||
retmsg="Fail! Please select knowledgebase!")
|
||||
dia = {
|
||||
"id": get_uuid(),
|
||||
"tenant_id": current_user.id,
|
||||
@ -86,17 +92,21 @@ def set_dialog():
|
||||
"similarity_threshold": similarity_threshold,
|
||||
"vector_similarity_weight": vector_similarity_weight
|
||||
}
|
||||
if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!")
|
||||
if not DialogService.save(**dia):
|
||||
return get_data_error_result(retmsg="Fail to new a dialog!")
|
||||
e, dia = DialogService.get_by_id(dia["id"])
|
||||
if not e: return get_data_error_result(retmsg="Fail to new a dialog!")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Fail to new a dialog!")
|
||||
return get_json_result(data=dia.to_json())
|
||||
else:
|
||||
del req["dialog_id"]
|
||||
if "kb_names" in req: del req["kb_names"]
|
||||
if "kb_names" in req:
|
||||
del req["kb_names"]
|
||||
if not DialogService.update_by_id(dialog_id, req):
|
||||
return get_data_error_result(retmsg="Dialog not found!")
|
||||
e, dia = DialogService.get_by_id(dialog_id)
|
||||
if not e: return get_data_error_result(retmsg="Fail to update a dialog!")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Fail to update a dialog!")
|
||||
dia = dia.to_dict()
|
||||
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
|
||||
return get_json_result(data=dia)
|
||||
@ -110,7 +120,8 @@ def get():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
e, dia = DialogService.get_by_id(dialog_id)
|
||||
if not e: return get_data_error_result(retmsg="Dialog not found!")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Dialog not found!")
|
||||
dia = dia.to_dict()
|
||||
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
|
||||
return get_json_result(data=dia)
|
||||
@ -122,7 +133,8 @@ def get_kb_names(kb_ids):
|
||||
ids, nms = [], []
|
||||
for kid in kb_ids:
|
||||
e, kb = KnowledgebaseService.get_by_id(kid)
|
||||
if not e or kb.status != StatusEnum.VALID.value: continue
|
||||
if not e or kb.status != StatusEnum.VALID.value:
|
||||
continue
|
||||
ids.append(kid)
|
||||
nms.append(kb.name)
|
||||
return ids, nms
|
||||
@ -132,7 +144,11 @@ def get_kb_names(kb_ids):
|
||||
@login_required
|
||||
def list():
|
||||
try:
|
||||
diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time)
|
||||
diags = DialogService.query(
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value,
|
||||
reverse=True,
|
||||
order_by=DialogService.model.create_time)
|
||||
diags = [d.to_dict() for d in diags]
|
||||
for d in diags:
|
||||
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
|
||||
@ -147,7 +163,8 @@ def list():
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
DialogService.update_many_by_id([{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
|
||||
DialogService.update_many_by_id(
|
||||
[{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -57,6 +57,9 @@ def upload():
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
if DocumentService.get_doc_count(kb.tenant_id) >= 128:
|
||||
return get_data_error_result(
|
||||
retmsg="Exceed the maximum file number of a free user!")
|
||||
|
||||
filename = duplicate_name(
|
||||
DocumentService.query,
|
||||
@ -215,9 +218,11 @@ def rm():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
|
||||
if not DocumentService.delete(doc):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Document removal)!")
|
||||
@ -245,7 +250,8 @@ def run():
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@ -261,7 +267,8 @@ def rename():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
doc.name.lower()).suffix:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
retmsg="The extension of file can't be changed",
|
||||
@ -294,7 +301,10 @@ def get(doc_id):
|
||||
if doc.type == FileType.VISUAL.value:
|
||||
response.headers.set('Content-Type', 'image/%s' % ext.group(1))
|
||||
else:
|
||||
response.headers.set('Content-Type', 'application/%s' % ext.group(1))
|
||||
response.headers.set(
|
||||
'Content-Type',
|
||||
'application/%s' %
|
||||
ext.group(1))
|
||||
return response
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -313,9 +323,11 @@ def change_parser():
|
||||
if "parser_config" in req:
|
||||
if req["parser_config"] == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else: return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
|
||||
if doc.type == FileType.VISUAL or re.search(
|
||||
r"\.(ppt|pptx|pages)$", doc.name):
|
||||
return get_data_error_result(retmsg="Not supported yet!")
|
||||
|
||||
e = DocumentService.update_by_id(doc.id,
|
||||
@ -332,7 +344,8 @@ def change_parser():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
|
||||
@ -33,15 +33,21 @@ from api.utils.api_utils import get_json_result
|
||||
def create():
|
||||
req = request.json
|
||||
req["name"] = req["name"].strip()
|
||||
req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)
|
||||
req["name"] = duplicate_name(
|
||||
KnowledgebaseService.query,
|
||||
name=req["name"],
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value)
|
||||
try:
|
||||
req["id"] = get_uuid()
|
||||
req["tenant_id"] = current_user.id
|
||||
req["created_by"] = current_user.id
|
||||
e, t = TenantService.get_by_id(current_user.id)
|
||||
if not e: return get_data_error_result(retmsg="Tenant not found.")
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Tenant not found.")
|
||||
req["embd_id"] = t.embd_id
|
||||
if not KnowledgebaseService.save(**req): return get_data_error_result()
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id": req["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -54,21 +60,29 @@ def update():
|
||||
req = request.json
|
||||
req["name"] = req["name"].strip()
|
||||
try:
|
||||
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
if not KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
||||
if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
|
||||
if req["name"].lower() != kb.name.lower() \
|
||||
and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1:
|
||||
return get_data_error_result(retmsg="Duplicated knowledgebase name.")
|
||||
and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
|
||||
return get_data_error_result(
|
||||
retmsg="Duplicated knowledgebase name.")
|
||||
|
||||
del req["kb_id"]
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result()
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_data_error_result()
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Knowledgebase rename)!")
|
||||
|
||||
return get_json_result(data=kb.to_json())
|
||||
except Exception as e:
|
||||
@ -81,7 +95,9 @@ def detail():
|
||||
kb_id = request.args["kb_id"]
|
||||
try:
|
||||
kb = KnowledgebaseService.get_detail(kb_id)
|
||||
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||
if not kb:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
return get_json_result(data=kb)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -96,7 +112,8 @@ def list():
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
|
||||
kbs = KnowledgebaseService.get_by_tenant_ids(
|
||||
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
|
||||
return get_json_result(data=kbs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -108,10 +125,15 @@ def list():
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
if not KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.INVALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
|
||||
if not KnowledgebaseService.update_by_id(
|
||||
req["kb_id"], {"status": StatusEnum.INVALID.value}):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Knowledgebase removal)!")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
@ -48,30 +48,42 @@ def set_api_key():
|
||||
req["api_key"], llm.llm_name)
|
||||
try:
|
||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||
if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
|
||||
if len(arr[0]) == 0 or tc == 0:
|
||||
raise Exception("Fail")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
|
||||
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
||||
mdl = ChatModel[factory](
|
||||
req["api_key"], llm.llm_name)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||
if not tc: raise Exception(m)
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
|
||||
"temperature": 0.9})
|
||||
if not tc:
|
||||
raise Exception(m)
|
||||
chat_passed = True
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
||||
e)
|
||||
|
||||
if msg: return get_data_error_result(retmsg=msg)
|
||||
if msg:
|
||||
return get_data_error_result(retmsg=msg)
|
||||
|
||||
llm = {
|
||||
"api_key": req["api_key"]
|
||||
}
|
||||
for n in ["model_type", "llm_name"]:
|
||||
if n in req: llm[n] = req[n]
|
||||
if n in req:
|
||||
llm[n] = req[n]
|
||||
|
||||
if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm):
|
||||
if not TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm):
|
||||
for llm in LLMService.query(fid=factory):
|
||||
TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"])
|
||||
TenantLLMService.save(
|
||||
tenant_id=current_user.id,
|
||||
llm_factory=factory,
|
||||
llm_name=llm.llm_name,
|
||||
model_type=llm.model_type,
|
||||
api_key=req["api_key"])
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -105,17 +117,19 @@ def list():
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
|
||||
llms = [m.to_dict()
|
||||
for m in llms if m.status == StatusEnum.VALID.value]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
|
||||
|
||||
res = {}
|
||||
for m in llms:
|
||||
if model_type and m["model_type"] != model_type: continue
|
||||
if m["fid"] not in res: res[m["fid"]] = []
|
||||
if model_type and m["model_type"] != model_type:
|
||||
continue
|
||||
if m["fid"] not in res:
|
||||
res[m["fid"]] = []
|
||||
res[m["fid"]].append(m)
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -40,13 +40,16 @@ def login():
|
||||
|
||||
email = request.json.get('email', "")
|
||||
users = UserService.query(email=email)
|
||||
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||
if not users:
|
||||
return get_json_result(
|
||||
data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||
|
||||
password = request.json.get('password')
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except:
|
||||
return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
|
||||
except BaseException:
|
||||
return get_json_result(
|
||||
data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
if user:
|
||||
@ -57,7 +60,8 @@ def login():
|
||||
msg = "Welcome back!"
|
||||
return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
|
||||
else:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||
retmsg='Email and Password do not match!')
|
||||
|
||||
|
||||
@manager.route('/github_callback', methods=['GET'])
|
||||
@ -65,7 +69,7 @@ def github_callback():
|
||||
import requests
|
||||
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
||||
"client_id": GITHUB_OAUTH.get("client_id"),
|
||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||
"code": request.args.get('code')
|
||||
}, headers={"Accept": "application/json"})
|
||||
res = res.json()
|
||||
@ -96,15 +100,17 @@ def github_callback():
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
})
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
if not users:
|
||||
raise Exception('Register user failure.')
|
||||
if len(users) > 1:
|
||||
raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return redirect("/?auth=%s"%user.get_id())
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
return redirect("/?error=%s"%str(e))
|
||||
return redirect("/?error=%s" % str(e))
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
@ -114,11 +120,18 @@ def github_callback():
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
import requests
|
||||
headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
headers = {"Accept": "application/json",
|
||||
'Authorization': f"token {access_token}"}
|
||||
res = requests.get(
|
||||
f"https://api.github.com/user?access_token={access_token}",
|
||||
headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json()
|
||||
user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"]
|
||||
email_info = requests.get(
|
||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||
headers=headers).json()
|
||||
user_info["email"] = next(
|
||||
(email for email in email_info if email['primary'] == True),
|
||||
None)["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
@ -138,13 +151,18 @@ def setting_user():
|
||||
request_data = request.json
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
|
||||
if not check_password_hash(
|
||||
current_user.password, decrypt(request_data["password"])):
|
||||
return get_json_result(
|
||||
data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
|
||||
|
||||
if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
|
||||
if new_password:
|
||||
update_dict["password"] = generate_password_hash(
|
||||
decrypt(new_password))
|
||||
|
||||
for k in request_data.keys():
|
||||
if k in ["password", "new_password"]:continue
|
||||
if k in ["password", "new_password"]:
|
||||
continue
|
||||
update_dict[k] = request_data[k]
|
||||
|
||||
try:
|
||||
@ -152,7 +170,8 @@ def setting_user():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
return get_json_result(
|
||||
data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/info", methods=["GET"])
|
||||
@ -173,11 +192,11 @@ def rollback_user_registration(user_id):
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).excute()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def user_register(user_id, user):
|
||||
user["id"] = user_id
|
||||
tenant = {
|
||||
@ -197,9 +216,14 @@ def user_register(user_id, user):
|
||||
}
|
||||
tenant_llm = []
|
||||
for llm in LLMService.query(fid=LLM_FACTORY):
|
||||
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
||||
tenant_llm.append({"tenant_id": user_id,
|
||||
"llm_factory": LLM_FACTORY,
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": API_KEY})
|
||||
|
||||
if not UserService.save(**user):return
|
||||
if not UserService.save(**user):
|
||||
return
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
@ -211,7 +235,8 @@ def user_register(user_id, user):
|
||||
def user_add():
|
||||
req = request.json
|
||||
if UserService.query(email=req["email"]):
|
||||
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
|
||||
return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
|
||||
retcode=RetCode.OPERATING_ERROR)
|
||||
@ -229,16 +254,19 @@ def user_add():
|
||||
user_id = get_uuid()
|
||||
try:
|
||||
users = user_register(user_id, user_dict)
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
if not users:
|
||||
raise Exception('Register user failure.')
|
||||
if len(users) > 1:
|
||||
raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
||||
return cors_reponse(data=user.to_json(),
|
||||
auth=user.get_id(), retmsg="Welcome aboard!")
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
return get_json_result(
|
||||
data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/tenant_info", methods=["GET"])
|
||||
|
||||
Reference in New Issue
Block a user