mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 12:06:42 +08:00
llm configuation refine and trievalTest API refine (#40)
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
@ -177,6 +178,7 @@ def create():
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
d["important_kwd"] = req.get("important_kwd", [])
|
||||
d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
|
||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
@ -223,7 +225,7 @@ def retrieval_test():
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
kb.tenant_id, LLMType.EMBEDDING.value)
|
||||
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
|
||||
vector_similarity_weight, top, doc_ids)
|
||||
vector_similarity_weight, top, doc_ids)
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
@ -231,4 +233,3 @@ def retrieval_test():
|
||||
return get_json_result(data=False, retmsg=f'Index not found!',
|
||||
retcode=RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -13,22 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
|
||||
import tiktoken
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import login_required
|
||||
from api.db.services.dialog_service import DialogService, ConversationService
|
||||
from api.db import StatusEnum, LLMType
|
||||
from api.db.services.kb_service import KnowledgebaseService
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.llm import ChatModel
|
||||
from rag.nlp import retrievaler
|
||||
from rag.nlp.query import EsQueryer
|
||||
from rag.utils import num_tokens_from_string, encoder
|
||||
|
||||
|
||||
@ -142,6 +136,27 @@ def message_fit_in(msg, max_length=4000):
|
||||
return max_length, msg
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("dialog_id", "messages")
|
||||
def completion():
|
||||
req = request.json
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":continue
|
||||
if m["role"] == "assistant" and not msg:continue
|
||||
msg.append({"role": m["role"], "content": m["content"]})
|
||||
try:
|
||||
e, dia = DialogService.get_by_id(req["dialog_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Dialog not found!")
|
||||
del req["dialog_id"]
|
||||
del req["messages"]
|
||||
return get_json_result(data=chat(dia, msg, **req))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
def chat(dialog, messages, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
llm = LLMService.query(llm_name=dialog.llm_id)
|
||||
@ -156,7 +171,7 @@ def chat(dialog, messages, **kwargs):
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
|
||||
|
||||
model_config = TenantLLMService.get_api_key(dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id)
|
||||
if not model_config: raise LookupError("LLM(%s) API key not found"%dialog.llm_id)
|
||||
if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
|
||||
|
||||
question = messages[-1]["content"]
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
@ -183,25 +198,4 @@ def chat(dialog, messages, **kwargs):
|
||||
embd_mdl,
|
||||
tkweight=1-dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
return {"answer": answer, "retrieval": kbinfos}
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("dialog_id", "messages")
|
||||
def completion():
|
||||
req = request.json
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":continue
|
||||
if m["role"] == "assistant" and not msg:continue
|
||||
msg.append({"role": m["role"], "content": m["content"]})
|
||||
try:
|
||||
e, dia = DialogService.get_by_id(req["dialog_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Dialog not found!")
|
||||
del req["dialog_id"]
|
||||
del req["messages"]
|
||||
return get_json_result(data=chat(dia, msg, **req))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
return {"answer": answer, "retrieval": kbinfos}
|
||||
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -71,18 +71,12 @@ def my_llms():
|
||||
def list():
|
||||
try:
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
objs = [o.to_dict() for o in objs if o.api_key]
|
||||
fct = {}
|
||||
for o in objs:
|
||||
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
|
||||
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
|
||||
|
||||
mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
|
||||
for m in llms:
|
||||
m["available"] = False
|
||||
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
|
||||
m["available"] = True
|
||||
m["available"] = m.llm_name in mdlnms
|
||||
|
||||
res = {}
|
||||
for m in llms:
|
||||
if m["fid"] not in res: res[m["fid"]] = []
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -13,12 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
|
||||
from flask import request, session, redirect, url_for
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
from flask_login import login_required, current_user, login_user, logout_user
|
||||
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.llm_service import TenantLLMService, LLMService
|
||||
from api.utils.api_utils import server_error_response, validate_request
|
||||
from api.utils import get_uuid, get_format_time, decrypt, download_img
|
||||
from api.db import UserTenantRole, LLMType
|
||||
@ -185,8 +187,6 @@ def rollback_user_registration(user_id):
|
||||
|
||||
|
||||
def user_register(user_id, user):
|
||||
|
||||
user_id = get_uuid()
|
||||
user["id"] = user_id
|
||||
tenant = {
|
||||
"id": user_id,
|
||||
@ -203,12 +203,14 @@ def user_register(user_id, user):
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER
|
||||
}
|
||||
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
|
||||
tenant_llm = []
|
||||
for llm in LLMService.query(fid="Infiniflow"):
|
||||
tenant_llm.append({"tenant_id": user_id, "llm_factory": "Infiniflow", "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": "infiniflow API Key"})
|
||||
|
||||
if not UserService.save(**user):return
|
||||
TenantService.save(**tenant)
|
||||
UserTenantService.save(**usr_tenant)
|
||||
TenantLLMService.save(**tenant_llm)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
return UserService.query(email=user["email"])
|
||||
|
||||
|
||||
@ -218,6 +220,9 @@ def user_add():
|
||||
req = request.json
|
||||
if UserService.query(email=req["email"]):
|
||||
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
|
||||
return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
|
||||
retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
user_dict = {
|
||||
"access_token": get_uuid(),
|
||||
|
||||
Reference in New Issue
Block a user