Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Feat: conversation completion can specify different model (#9485)

### What problem does this PR solve?

Conversation completion can now specify a different model per request.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
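The practical effect: a single completion request can carry an `llm_id` plus sampling parameters, and the dialog's configured chat model is swapped out for that request only. Below is a minimal sketch of such a request; the host, port, path, auth scheme, and model name are illustrative assumptions, while `conversation_id`, `messages`, `llm_id`, and the sampling keys are the fields the diff actually reads:

```python
import requests

# Hypothetical call against the conversation completion endpoint.
payload = {
    "conversation_id": "<existing-conversation-id>",
    "messages": [{"role": "user", "content": "Summarize the attached document."}],
    "llm_id": "deepseek-chat@DeepSeek",  # per-request model override (illustrative name)
    # Only these keys are whitelisted into chat_model_config (see the diff):
    "temperature": 0.2,
    "top_p": 0.9,
    "max_tokens": 1024,
}

resp = requests.post(
    "http://localhost:9380/v1/conversation/completion",  # assumed host/port/path
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # assumed auth scheme
)
print(resp.text)
```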
`api/apps/conversation_app.py`:

```diff
@@ -29,7 +29,8 @@ from api.db.services.conversation_service import ConversationService, structure_answer
 from api.db.services.dialog_service import DialogService, ask, chat
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.db.services.user_service import UserTenantService, TenantService
+from api.db.services.tenant_llm_service import TenantLLMService
+from api.db.services.user_service import TenantService, UserTenantService
 from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
 from graphrag.general.mind_map_extractor import MindMapExtractor
 from rag.app.tag import label_question
```
```diff
@@ -66,8 +67,14 @@ def set_conversation():
         e, dia = DialogService.get_by_id(req["dialog_id"])
         if not e:
             return get_data_error_result(message="Dialog not found")
-        conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": name, "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],"user_id": current_user.id,
-                "reference":[],}
+        conv = {
+            "id": conv_id,
+            "dialog_id": req["dialog_id"],
+            "name": name,
+            "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
+            "user_id": current_user.id,
+            "reference": [],
+        }
         ConversationService.save(**conv)
         return get_json_result(data=conv)
     except Exception as e:
```
```diff
@@ -174,6 +181,21 @@ def completion():
             continue
         msg.append(m)
     message_id = msg[-1].get("id")
+    chat_model_id = req.get("llm_id", "")
+    req.pop("llm_id", None)
+
+    chat_model_config = {}
+    for model_config in [
+        "temperature",
+        "top_p",
+        "frequency_penalty",
+        "presence_penalty",
+        "max_tokens",
+    ]:
+        config = req.get(model_config)
+        if config:
+            chat_model_config[model_config] = config
+
     try:
         e, conv = ConversationService.get_by_id(req["conversation_id"])
         if not e:
```
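The loop above is a whitelist copy: only the five recognized sampling keys are lifted out of the request into `chat_model_config`, and only when truthy. A standalone sketch of the same pattern (not the handler itself), which also shows the one edge this style has: falsy-but-valid values such as `temperature=0` are dropped by the `if config:` test.

```python
# Standalone reproduction of the whitelist extraction above.
ALLOWED_KEYS = ["temperature", "top_p", "frequency_penalty", "presence_penalty", "max_tokens"]

def extract_chat_model_config(req: dict) -> dict:
    config = {}
    for key in ALLOWED_KEYS:
        value = req.get(key)
        if value:  # truthy check, mirroring the diff
            config[key] = value
    return config

print(extract_chat_model_config({"temperature": 0.2, "max_tokens": 1024, "stop": ["###"]}))
# {'temperature': 0.2, 'max_tokens': 1024}  -- unknown keys are ignored
print(extract_chat_model_config({"temperature": 0}))
# {}  -- a legitimate temperature of 0 is dropped by the truthy check
```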
```diff
@@ -190,13 +212,23 @@ def completion():
         conv.reference = [r for r in conv.reference if r]
         conv.reference.append({"chunks": [], "doc_aggs": []})

+        if chat_model_id:
+            if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
+                req.pop("chat_model_id", None)
+                req.pop("chat_model_config", None)
+                return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
+            dia.llm_id = chat_model_id
+            dia.llm_setting = chat_model_config
+
+        is_embedded = bool(chat_model_id)
         def stream():
             nonlocal dia, msg, req, conv
             try:
                 for ans in chat(dia, msg, True, **req):
                     ans = structure_answer(conv, ans, message_id, conv.id)
                     yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
-                ConversationService.update_by_id(conv.id, conv.to_dict())
+                if not is_embedded:
+                    ConversationService.update_by_id(conv.id, conv.to_dict())
             except Exception as e:
                 traceback.print_exc()
                 yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
```
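Two things happen in this hunk: the override is validated (`TenantLLMService.get_api_key` must return a key for the tenant and model, otherwise the request is rejected with a data error), and `is_embedded` marks the request as using an ad-hoc model, so the in-memory mutation of `dia` (`llm_id`, `llm_setting`) and the conversation state are not written back by the `ConversationService.update_by_id` calls that follow.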
```diff
@@ -214,7 +246,8 @@ def completion():
             answer = None
             for ans in chat(dia, msg, **req):
                 answer = structure_answer(conv, ans, message_id, conv.id)
-                ConversationService.update_by_id(conv.id, conv.to_dict())
+                if not is_embedded:
+                    ConversationService.update_by_id(conv.id, conv.to_dict())
                 break
             return get_json_result(data=answer)
     except Exception as e:
```
`api/db/db_models.py`:

```diff
@@ -881,11 +881,12 @@ class Search(DataBaseModel):
             # chat settings
             "summary": False,
             "chat_id": "",
+            # Leave it here for reference, don't need to set default values
             "llm_setting": {
-                "temperature": 0.1,
-                "top_p": 0.3,
-                "frequency_penalty": 0.7,
-                "presence_penalty": 0.4,
+                # "temperature": 0.1,
+                # "top_p": 0.3,
+                # "frequency_penalty": 0.7,
+                # "presence_penalty": 0.4,
             },
             "chat_settingcross_languages": [],
             "highlight": False,
```
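The hard-coded sampling defaults are dropped from the `Search` model's default config; per the new comment, the commented-out block is kept only as a reference for the available keys.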
```diff
@@ -1020,4 +1021,4 @@ def migrate_db():
         migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
     except Exception:
         pass
     logging.disable(logging.NOTSET)
```
`api/db/services/dialog_service.py`:

```diff
@@ -99,7 +99,6 @@ class DialogService(CommonService):

         return list(chats.dicts())

-
     @classmethod
     @DB.connection_context()
     def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None):
```
```diff
@@ -256,9 +255,10 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):

 def meta_filter(metas: dict, filters: list[dict]):
     doc_ids = []
+
     def filter_out(v2docs, operator, value):
         nonlocal doc_ids
-        for input,docids in v2docs.items():
+        for input, docids in v2docs.items():
             try:
                 input = float(input)
                 value = float(value)
```
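The `float(...)` casts inside the `try` explain the loop's intent: metadata values arrive as strings, so each comparison first attempts a numeric interpretation and presumably falls back to string comparison when the cast raises. A self-contained sketch of that cast-and-fallback pattern (the equality operator is an illustrative stand-in; only the coercion itself is visible in this hunk):

```python
def values_match(stored, wanted) -> bool:
    """Compare a stored metadata value against a filter value,
    numerically when both sides parse as numbers, else as strings."""
    try:
        return float(stored) == float(wanted)  # "3.0" matches 3
    except (TypeError, ValueError):
        return str(stored) == str(wanted)      # fall back to string equality

assert values_match("3.0", 3)
assert values_match("draft", "draft")
assert not values_match("3.0", "3.5")
```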
```diff
@@ -389,7 +389,17 @@ def chat(dialog, messages, stream=True, **kwargs):
             reasoner = DeepResearcher(
                 chat_mdl,
                 prompt_config,
-                partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3, doc_ids=attachments),
+                partial(
+                    retriever.retrieval,
+                    embd_mdl=embd_mdl,
+                    tenant_ids=tenant_ids,
+                    kb_ids=dialog.kb_ids,
+                    page=1,
+                    page_size=dialog.top_n,
+                    similarity_threshold=0.2,
+                    vector_similarity_weight=0.3,
+                    doc_ids=attachments,
+                ),
             )

             for think in reasoner.thinking(kbinfos, " ".join(questions)):
```
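The reformatting also makes the design choice easier to see: `functools.partial` pre-binds every retrieval argument once, so `DeepResearcher` receives a plain callable and stays ignorant of embedding models, tenants, and paging. A minimal illustration of the idiom with made-up stand-in names:

```python
from functools import partial

def retrieval(question, *, kb_ids, page, page_size):
    # Stand-in for retriever.retrieval: search the given knowledge bases.
    return f"searching {kb_ids} for {question!r} (page {page}, size {page_size})"

# Bind everything except the question, just as the diff binds everything
# except what the reasoner supplies at call time.
retriever_fn = partial(retrieval, kb_ids=["kb_1"], page=1, page_size=5)

print(retriever_fn("what changed in #9485?"))
```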