deal with stop reason being length problem (#109)

This commit is contained in:
KevinHuSh
2024-03-07 16:12:01 +08:00
committed by GitHub
parent b69b5dd4e5
commit 2d7c9080f4
6 changed files with 59 additions and 27 deletions

View File

@ -176,7 +176,7 @@ def chat(dialog, messages, **kwargs):
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
llm = llm[0]
question = messages[-1]["content"]
questions = [m["content"] for m in messages if m["role"] == "user"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@ -184,7 +184,7 @@ def chat(dialog, messages, **kwargs):
## try to use sql if field mapping is good to go
if field_map:
stat_logger.info("Use SQL to retrieval.")
markdown_tbl, chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
if markdown_tbl:
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
@ -195,7 +195,9 @@ def chat(dialog, messages, **kwargs):
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
for _ in range(len(questions)//2):
questions.append(questions[-1])
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
@ -224,13 +226,14 @@ def chat(dialog, messages, **kwargs):
def use_sql(question, field_map, tenant_id, chat_mdl):
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构根据的问题写出sql"
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构根据用户的问题列表写出最后一个问题对应的SQL"
user_promt = """
表名:{}
数据库表字段说明如下:
{}
问题:{}
问题如下
{}
请写出SQL且只要SQL不要有其他说明及文字。
""".format(
index_name(tenant_id),

View File

@ -100,12 +100,14 @@ def github_callback():
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
return redirect("/?auth=%s"%user.get_id())
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return redirect("/?error=%s"%str(e))
return redirect("/?auth=%s"%user_id)
user = users[0]
login_user(user)
return redirect("/?auth=%s" % user.get_id())
def user_info_from_github(access_token):