mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: migrate chat models to LiteLLM (#9394)
### What problem does this PR solve? All models pass the mock-response tests, which means that if a model returns a correct response, everything should work as expected. However, not all models have been fully tested in a real environment with a real API key. I suggest actively monitoring the refactored models over the coming period to ensure they work correctly and fixing them step by step, or waiting to merge until most have been tested in a practical environment. ### Type of change - [x] Refactoring
This commit is contained in:
@ -57,6 +57,7 @@ def set_api_key():
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = req["llm_factory"]
|
||||
extra = {"provider": factory}
|
||||
msg = ""
|
||||
for llm in LLMService.query(fid=factory):
|
||||
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
|
||||
@ -73,7 +74,7 @@ def set_api_key():
|
||||
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||
mdl = ChatModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
|
||||
{"temperature": 0.9, 'max_tokens': 50})
|
||||
@ -204,6 +205,7 @@ def add_llm():
|
||||
|
||||
msg = ""
|
||||
mdl_nm = llm["llm_name"].split("___")[0]
|
||||
extra = {"provider": factory}
|
||||
if llm["model_type"] == LLMType.EMBEDDING.value:
|
||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||
mdl = EmbeddingModel[factory](
|
||||
@ -221,7 +223,8 @@ def add_llm():
|
||||
mdl = ChatModel[factory](
|
||||
key=llm['api_key'],
|
||||
model_name=mdl_nm,
|
||||
base_url=llm["api_base"]
|
||||
base_url=llm["api_base"],
|
||||
**extra,
|
||||
)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
|
||||
@ -312,12 +315,12 @@ def delete_factory():
|
||||
def my_llms():
|
||||
try:
|
||||
include_details = request.args.get('include_details', 'false').lower() == 'true'
|
||||
|
||||
|
||||
if include_details:
|
||||
res = {}
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
factories = LLMFactoriesService.query(status=StatusEnum.VALID.value)
|
||||
|
||||
|
||||
for o in objs:
|
||||
o_dict = o.to_dict()
|
||||
factory_tags = None
|
||||
@ -325,13 +328,13 @@ def my_llms():
|
||||
if f.name == o_dict["llm_factory"]:
|
||||
factory_tags = f.tags
|
||||
break
|
||||
|
||||
|
||||
if o_dict["llm_factory"] not in res:
|
||||
res[o_dict["llm_factory"]] = {
|
||||
"tags": factory_tags,
|
||||
"llm": []
|
||||
}
|
||||
|
||||
|
||||
res[o_dict["llm_factory"]]["llm"].append({
|
||||
"type": o_dict["model_type"],
|
||||
"name": o_dict["llm_name"],
|
||||
@ -352,7 +355,7 @@ def my_llms():
|
||||
"name": o["llm_name"],
|
||||
"used_token": o["used_tokens"]
|
||||
})
|
||||
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
Reference in New Issue
Block a user