diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py
index 11e7b3efe..575eea695 100644
--- a/api/db/services/conversation_service.py
+++ b/api/db/services/conversation_service.py
@@ -23,6 +23,8 @@ from api.db.services.dialog_service import DialogService, chat
 from api.utils import get_uuid
 import json
 
+from rag.prompts import chunks_format
+
 
 class ConversationService(CommonService):
     model = Conversation
@@ -53,19 +55,7 @@ def structure_answer(conv, ans, message_id, session_id):
         reference = {}
         ans["reference"] = {}
 
-    def get_value(d, k1, k2):
-        return d.get(k1, d.get(k2))
-
-    chunk_list = [{
-        "id": get_value(chunk, "chunk_id", "id"),
-        "content": get_value(chunk, "content", "content_with_weight"),
-        "document_id": get_value(chunk, "doc_id", "document_id"),
-        "document_name": get_value(chunk, "docnm_kwd", "document_name"),
-        "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
-        "image_id": get_value(chunk, "image_id", "img_id"),
-        "positions": get_value(chunk, "positions", "position_int"),
-        "url": chunk.get("url")
-    } for chunk in reference.get("chunks", [])]
+    chunk_list = chunks_format(reference)
 
     reference["chunks"] = chunk_list
     ans["id"] = message_id
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 6126ef150..3cf599af6 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -30,7 +30,7 @@ from api import settings
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question
+from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format
 from rag.utils import rmSpace, num_tokens_from_string
 from rag.utils.tavily_conn import Tavily
 
@@ -511,7 +511,7 @@ def ask(question, kb_ids, tenant_id):
 
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-        return {"answer": answer, "reference": refs}
+        return {"answer": answer, "reference": chunks_format(refs)}
 
     answer = ""
     for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
diff --git a/rag/prompts.py b/rag/prompts.py
index e8497fee3..4a07f175c 100644
--- a/rag/prompts.py
+++ b/rag/prompts.py
@@ -28,6 +28,22 @@ from rag.settings import TAG_FLD
 from rag.utils import num_tokens_from_string, encoder
 
 
+def chunks_format(reference):
+    def get_value(d, k1, k2):
+        return d.get(k1, d.get(k2))
+
+    return [{
+        "id": get_value(chunk, "chunk_id", "id"),
+        "content": get_value(chunk, "content", "content_with_weight"),
+        "document_id": get_value(chunk, "doc_id", "document_id"),
+        "document_name": get_value(chunk, "docnm_kwd", "document_name"),
+        "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
+        "image_id": get_value(chunk, "image_id", "img_id"),
+        "positions": get_value(chunk, "positions", "position_int"),
+        "url": chunk.get("url")
+    } for chunk in reference.get("chunks", [])]
+
+
 def llm_id2llm_type(llm_id):
     llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
     fnm = os.path.join(get_project_base_directory(), "conf")