Mirror of https://github.com/infiniflow/ragflow.git
Support Ollama (#261)
### What problem does this PR solve?

Issue link: #221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
In `message_fit_in`, chat history entries are plain dicts, so the role check moves from attribute access to key lookup:

```diff
@@ -126,7 +126,7 @@ def message_fit_in(msg, max_length=4000):
     if c < max_length:
         return c, msg

-    msg_ = [m for m in msg[:-1] if m.role == "system"]
+    msg_ = [m for m in msg[:-1] if m["role"] == "system"]
     msg_.append(msg[-1])
     msg = msg_
     c = count()
```
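A quick illustration of why that one-liner matters: OpenAI-style chat histories are lists of plain dicts, so attribute access raises. This is a minimal repro, not part of the commit:

```python
# Minimal repro: message entries are dicts, not objects.
msg = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]

# msg[0].role  -> AttributeError: 'dict' object has no attribute 'role'
system_msgs = [m for m in msg[:-1] if m["role"] == "system"]  # key lookup works
system_msgs.append(msg[-1])  # keep the latest turn, as message_fit_in does
```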
In `upload()`, the document record reuses the previously computed `filetype` instead of calling `filename_type(filename)` a second time:

```diff
@@ -81,7 +81,7 @@ def upload():
             "parser_id": kb.parser_id,
             "parser_config": kb.parser_config,
             "created_by": current_user.id,
-            "type": filename_type(filename),
+            "type": filetype,
             "name": filename,
             "location": location,
             "size": len(blob),
```
A new `/add_llm` endpoint lets a tenant register a model (e.g. one served by Ollama) and probes it before saving:

```diff
@@ -91,6 +91,57 @@ def set_api_key():
     return get_json_result(data=True)


+@manager.route('/add_llm', methods=['POST'])
+@login_required
+@validate_request("llm_factory", "llm_name", "model_type")
+def add_llm():
+    req = request.json
+    llm = {
+        "tenant_id": current_user.id,
+        "llm_factory": req["llm_factory"],
+        "model_type": req["model_type"],
+        "llm_name": req["llm_name"],
+        "api_base": req.get("api_base", ""),
+        "api_key": "xxxxxxxxxxxxxxx"
+    }
+
+    factory = req["llm_factory"]
+    msg = ""
+    if llm["model_type"] == LLMType.EMBEDDING.value:
+        mdl = EmbeddingModel[factory](
+            key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+        try:
+            arr, tc = mdl.encode(["Test if the api key is available"])
+            if len(arr[0]) == 0 or tc == 0:
+                raise Exception("Fail")
+        except Exception as e:
+            msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
+    elif llm["model_type"] == LLMType.CHAT.value:
+        mdl = ChatModel[factory](
+            key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+        try:
+            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
+                "temperature": 0.9})
+            if not tc:
+                raise Exception(m)
+        except Exception as e:
+            msg += f"\nFail to access model({llm['llm_name']})." + str(e)
+    else:
+        # TODO: check other types of models
+        pass
+
+    if msg:
+        return get_data_error_result(retmsg=msg)
+
+    if not TenantLLMService.filter_update(
+            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
+        TenantLLMService.save(**llm)
+
+    return get_json_result(data=True)
+
+
 @manager.route('/my_llms', methods=['GET'])
 @login_required
 def my_llms():
```
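For reference, a client call against the new endpoint might look like the sketch below. The JSON field names come straight from `validate_request` and `add_llm()` above; the host, port, route prefix, session cookie, and model name are placeholder assumptions, not part of this commit:

```python
# Hypothetical client sketch; only the payload field names are taken from add_llm().
import requests

payload = {
    "llm_factory": "Ollama",               # factory name seeded in init_data
    "llm_name": "llama2:13b",              # any model your local Ollama serves
    "model_type": "chat",                  # assumed to equal LLMType.CHAT.value
    "api_base": "http://localhost:11434",  # Ollama's default listen address
}

resp = requests.post(
    "http://localhost:9380/v1/llm/add_llm",       # assumed API prefix and port
    json=payload,
    cookies={"session": "<your-login-session>"},  # endpoint is @login_required
)
print(resp.json())  # {"data": true} on success, an error message otherwise
```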
In `list()`, tenant-configured models that are missing from the factory catalogue (such as freshly registered Ollama models) are now appended to the result:

```diff
@@ -125,6 +176,12 @@ def list():
     for m in llms:
         m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"

+    llm_set = set([m["llm_name"] for m in llms])
+    for o in objs:
+        if not o.api_key: continue
+        if o.llm_name in llm_set: continue
+        llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
+
     res = {}
     for m in llms:
         if model_type and m["model_type"] != model_type:
```
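A standalone sketch of the merge these new lines perform, with illustrative data (the `SimpleNamespace` rows stand in for the tenant's `TenantLLM` records in `objs`):

```python
from types import SimpleNamespace

# Factory catalogue entries already in the response.
llms = [{"llm_name": "gpt-4", "model_type": "chat", "fid": "OpenAI", "available": True}]

# Tenant-configured models; attribute names mirror the loop in list().
objs = [
    SimpleNamespace(llm_name="llama2:13b", model_type="chat", llm_factory="Ollama", api_key="x"),
    SimpleNamespace(llm_name="gpt-4", model_type="chat", llm_factory="OpenAI", api_key="x"),   # already listed: skipped
    SimpleNamespace(llm_name="mistral", model_type="chat", llm_factory="Ollama", api_key=""),  # no key: skipped
]

llm_set = set(m["llm_name"] for m in llms)
for o in objs:
    if not o.api_key: continue
    if o.llm_name in llm_set: continue
    llms.append({"llm_name": o.llm_name, "model_type": o.model_type,
                 "fid": o.llm_factory, "available": True})

# llms now also contains the tenant-added Ollama model, exactly once.
```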
`rollback_user_registration` now also removes the tenant record, again swallowing errors so cleanup stays best-effort:

```diff
@@ -181,6 +181,10 @@ def user_info():


 def rollback_user_registration(user_id):
     try:
         UserService.delete_by_id(user_id)
     except Exception as e:
         pass
+    try:
+        TenantService.delete_by_id(user_id)
+    except Exception as e:
+        pass
```
The data initializer imports the `LLMFactories` and `LLM` models it now needs for cleanup:

```diff
@@ -18,7 +18,7 @@ import time
 import uuid

 from api.db import LLMType, UserTenantRole
-from api.db.db_models import init_database_tables as init_web_db
+from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
 from api.db.services import UserService
 from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
 from api.db.services.user_service import TenantService, UserTenantService
```
In `factory_infos`, the `Local` factory becomes `Ollama` (the repeated Moonshot line reads as a whitespace-only change in the original diff):

```diff
@@ -100,16 +100,16 @@ factory_infos = [{
     "status": "1",
 },
 {
-    "name": "Local",
+    "name": "Ollama",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
     "status": "1",
 }, {
-    "name": "Moonshot",
+    "name": "Moonshot",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
-}
+},
 # {
 #     "name": "文心一言",
 #     "logo": "",
```
The hard-coded seed models under the 本地 ("Local") section are dropped from `init_llm_factory()`; Ollama models are registered per tenant via `/add_llm` instead:

```diff
@@ -230,20 +230,6 @@ def init_llm_factory():
             "max_tokens": 512,
             "model_type": LLMType.EMBEDDING.value
         },
-        # ---------------------- 本地 ----------------------
-        {
-            "fid": factory_infos[3]["name"],
-            "llm_name": "qwen-14B-chat",
-            "tags": "LLM,CHAT,",
-            "max_tokens": 4096,
-            "model_type": LLMType.CHAT.value
-        }, {
-            "fid": factory_infos[3]["name"],
-            "llm_name": "flag-embedding",
-            "tags": "TEXT EMBEDDING,",
-            "max_tokens": 128 * 1000,
-            "model_type": LLMType.EMBEDDING.value
-        },
         # ------------------------ Moonshot -----------------------
         {
             "fid": factory_infos[4]["name"],
```
Startup now purges any `Local` factory and model rows left over from earlier installs:

```diff
@@ -282,6 +268,9 @@ def init_llm_factory():
         except Exception as e:
             pass

+    LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
+    LLMService.filter_delete([LLM.fid == "Local"])
+
     """
     drop table llm;
     drop table llm_factories;
```
`init_web_data()` now calls `init_llm_factory()` unconditionally, so the rename and cleanup above take effect on every startup rather than only when the factory count changes:

```diff
@@ -295,8 +284,7 @@ def init_llm_factory():
 def init_web_data():
     start_time = time.time()

-    if LLMFactoriesService.get_all().count() != len(factory_infos):
-        init_llm_factory()
+    init_llm_factory()
     if not UserService.get_all().count():
         init_superuser()
```
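Dropping the count guard is only safe because seeding is idempotent. The commit relies on the same update-or-insert idiom that `add_llm()` uses above; a minimal sketch of that pattern, with illustrative names:

```python
# filter_update returns a falsy value when no row matched, so an insert follows;
# re-running the initializer therefore converges instead of duplicating rows.
def upsert(service, match_clauses, row):
    if not service.filter_update(match_clauses, row):  # update the existing row
        service.save(**row)                            # or insert a new one
```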