diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index b44101135..147f76810 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -29,7 +29,8 @@ from api.db.services.conversation_service import ConversationService, structure_ from api.db.services.dialog_service import DialogService, ask, chat from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.db.services.user_service import UserTenantService, TenantService +from api.db.services.tenant_llm_service import TenantLLMService +from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from graphrag.general.mind_map_extractor import MindMapExtractor from rag.app.tag import label_question @@ -66,8 +67,14 @@ def set_conversation(): e, dia = DialogService.get_by_id(req["dialog_id"]) if not e: return get_data_error_result(message="Dialog not found") - conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": name, "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],"user_id": current_user.id, - "reference":[],} + conv = { + "id": conv_id, + "dialog_id": req["dialog_id"], + "name": name, + "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}], + "user_id": current_user.id, + "reference": [], + } ConversationService.save(**conv) return get_json_result(data=conv) except Exception as e: @@ -174,6 +181,21 @@ def completion(): continue msg.append(m) message_id = msg[-1].get("id") + chat_model_id = req.get("llm_id", "") + req.pop("llm_id", None) + + chat_model_config = {} + for model_config in [ + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + "max_tokens", + ]: + config = req.get(model_config) + if config: + chat_model_config[model_config] = config + try: e, conv = ConversationService.get_by_id(req["conversation_id"]) if not e: @@ -190,13 +212,23 @@ def completion(): conv.reference = [r for r in conv.reference if r] conv.reference.append({"chunks": [], "doc_aggs": []}) + if chat_model_id: + if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id): + req.pop("chat_model_id", None) + req.pop("chat_model_config", None) + return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.") + dia.llm_id = chat_model_id + dia.llm_setting = chat_model_config + + is_embedded = bool(chat_model_id) def stream(): nonlocal dia, msg, req, conv try: for ans in chat(dia, msg, True, **req): ans = structure_answer(conv, ans, message_id, conv.id) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" - ConversationService.update_by_id(conv.id, conv.to_dict()) + if not is_embedded: + ConversationService.update_by_id(conv.id, conv.to_dict()) except Exception as e: traceback.print_exc() yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" @@ -214,7 +246,8 @@ def completion(): answer = None for ans in chat(dia, msg, **req): answer = structure_answer(conv, ans, message_id, conv.id) - ConversationService.update_by_id(conv.id, conv.to_dict()) + if not is_embedded: + ConversationService.update_by_id(conv.id, conv.to_dict()) break return get_json_result(data=answer) except Exception as e: diff --git a/api/db/db_models.py b/api/db/db_models.py index 1db9e1078..cdd946eb7 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -881,11 +881,12 @@ class Search(DataBaseModel): # chat settings "summary": False, "chat_id": "", + # Leave it here for reference, don't need to set default values "llm_setting": { - "temperature": 0.1, - "top_p": 0.3, - "frequency_penalty": 0.7, - "presence_penalty": 0.4, + # "temperature": 0.1, + # "top_p": 0.3, + # "frequency_penalty": 0.7, + # "presence_penalty": 0.4, }, "chat_settingcross_languages": [], "highlight": False, @@ -1020,4 +1021,4 @@ def migrate_db(): migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={}))) except Exception: pass - logging.disable(logging.NOTSET) \ No newline at end of file + logging.disable(logging.NOTSET) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index aa4f8ba4e..0f637bc70 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -99,7 +99,6 @@ class DialogService(CommonService): return list(chats.dicts()) - @classmethod @DB.connection_context() def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None): @@ -256,9 +255,10 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): def meta_filter(metas: dict, filters: list[dict]): doc_ids = [] + def filter_out(v2docs, operator, value): nonlocal doc_ids - for input,docids in v2docs.items(): + for input, docids in v2docs.items(): try: input = float(input) value = float(value) @@ -389,7 +389,17 @@ def chat(dialog, messages, stream=True, **kwargs): reasoner = DeepResearcher( chat_mdl, prompt_config, - partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3, doc_ids=attachments), + partial( + retriever.retrieval, + embd_mdl=embd_mdl, + tenant_ids=tenant_ids, + kb_ids=dialog.kb_ids, + page=1, + page_size=dialog.top_n, + similarity_threshold=0.2, + vector_similarity_weight=0.3, + doc_ids=attachments, + ), ) for think in reasoner.thinking(kbinfos, " ".join(questions)):