Support iframe chatbot. (#3961)
### What problem does this PR solve?

#3909

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import json
+import traceback
 from uuid import uuid4
 from agent.canvas import Canvas
 from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
@@ -58,6 +59,8 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
     if not isinstance(cvs.dsl, str):
         cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
     canvas = Canvas(cvs.dsl, tenant_id)
+    canvas.reset()
+    message_id = str(uuid4())

     if not session_id:
         session_id = get_uuid()
@@ -84,40 +87,24 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
             return
         conv = API4Conversation(**conv)
     else:
         session_id = session_id
         e, conv = API4ConversationService.get_by_id(session_id)
         assert e, "Session not found!"
         canvas = Canvas(json.dumps(conv.dsl), tenant_id)
-
-    if not conv.message:
-        conv.message = []
-    messages = conv.message
-    question = {
-        "role": "user",
-        "content": question,
-        "id": str(uuid4())
-    }
-    messages.append(question)
-    msg = []
-    for m in messages:
-        if m["role"] == "system":
-            continue
-        if m["role"] == "assistant" and not msg:
-            continue
-        msg.append(m)
-    if not msg[-1].get("id"):
-        msg[-1]["id"] = get_uuid()
-    message_id = msg[-1]["id"]
-
-    if not conv.reference:
-        conv.reference = []
-    conv.message.append({"role": "assistant", "content": "", "id": message_id})
-    conv.reference.append({"chunks": [], "doc_aggs": []})
+        canvas.messages.append({"role": "user", "content": question, "id": message_id})
+        canvas.add_user_input(question)
+        if not conv.message:
+            conv.message = []
+        conv.message.append({
+            "role": "user",
+            "content": question,
+            "id": message_id
+        })
+        if not conv.reference:
+            conv.reference = []
+        conv.reference.append({"chunks": [], "doc_aggs": []})

     final_ans = {"reference": [], "content": ""}
-
-    canvas.add_user_input(msg[-1]["content"])
-
     if stream:
         try:
             for ans in canvas.run(stream=stream):
@@ -141,6 +128,7 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
             conv.dsl = json.loads(str(canvas))
             API4ConversationService.append_message(conv.id, conv.to_dict())
         except Exception as e:
+            traceback.print_exc()
             conv.dsl = json.loads(str(canvas))
             API4ConversationService.append_message(conv.id, conv.to_dict())
             yield "data:" + json.dumps({"code": 500, "message": str(e),
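The streamed frames above are plain SSE-style `data:` lines carrying JSON, and the new except-branch emits a `{"code": 500, "message": ...}` frame instead of dropping the connection. A minimal client sketch, assuming a hypothetical endpoint path, port, and API key (only the `data:` + JSON framing comes from the diff):

```python
import json
import requests

# Hypothetical URL and payload; consult the ragflow HTTP API reference
# for the real agent-completion endpoint.
resp = requests.post(
    "http://localhost:9380/api/v1/agents/<agent_id>/completions",
    headers={"Authorization": "Bearer <API_KEY>"},
    json={"question": "Hi!", "stream": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue
    event = json.loads(line[len("data:"):])
    # The except-branch above yields {"code": 500, "message": ...} on failure.
    if isinstance(event, dict) and event.get("code") == 500:
        raise RuntimeError(event["message"])
    print(event)
```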
@@ -21,7 +21,6 @@ from api.db.services.common_service import CommonService
 from api.db.services.dialog_service import DialogService, chat
 from api.utils import get_uuid
 import json
-from copy import deepcopy


 class ConversationService(CommonService):
@@ -49,30 +48,35 @@ def structure_answer(conv, ans, message_id, session_id):
     reference = ans["reference"]
     if not isinstance(reference, dict):
         reference = {}
-    temp_reference = deepcopy(ans["reference"])
-    if not conv.reference:
-        conv.reference.append(temp_reference)
-    else:
-        conv.reference[-1] = temp_reference
-    conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
     ans["reference"] = {}
+
+    def get_value(d, k1, k2):
+        return d.get(k1, d.get(k2))
+
     chunk_list = [{
-        "id": chunk["chunk_id"],
-        "content": chunk.get("content") if chunk.get("content") else chunk.get("content_with_weight"),
-        "document_id": chunk["doc_id"],
-        "document_name": chunk["docnm_kwd"],
-        "dataset_id": chunk["kb_id"],
-        "image_id": chunk["image_id"],
-        "similarity": chunk["similarity"],
-        "vector_similarity": chunk["vector_similarity"],
-        "term_similarity": chunk["term_similarity"],
-        "positions": chunk["positions"],
+        "id": get_value(chunk, "chunk_id", "id"),
+        "content": get_value(chunk, "content", "content_with_weight"),
+        "document_id": get_value(chunk, "doc_id", "document_id"),
+        "document_name": get_value(chunk, "docnm_kwd", "document_name"),
+        "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
+        "image_id": get_value(chunk, "image_id", "img_id"),
+        "positions": get_value(chunk, "positions", "position_int"),
     } for chunk in reference.get("chunks", [])]

     reference["chunks"] = chunk_list
     ans["id"] = message_id
     ans["session_id"] = session_id

+    if not conv:
+        return ans
+
+    if not conv.message:
+        conv.message = []
+    if not conv.message or conv.message[-1].get("role", "") != "assistant":
+        conv.message.append({"role": "assistant", "content": ans["answer"], "id": message_id})
+    else:
+        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
+    if conv.reference:
+        conv.reference[-1] = reference
     return ans
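The `get_value` helper lets `structure_answer` accept chunks carrying either the legacy internal key names (`chunk_id`, `doc_id`, `docnm_kwd`, ...) or the newer API-style names (`id`, `document_id`, `document_name`, ...). A quick illustration with made-up chunks:

```python
def get_value(d, k1, k2):
    return d.get(k1, d.get(k2))

legacy_chunk = {"chunk_id": "c1", "doc_id": "d1", "docnm_kwd": "a.pdf"}
api_chunk = {"id": "c2", "document_id": "d2", "document_name": "b.pdf"}

for chunk in (legacy_chunk, api_chunk):
    print(get_value(chunk, "chunk_id", "id"),
          get_value(chunk, "doc_id", "document_id"),
          get_value(chunk, "docnm_kwd", "document_name"))
# c1 d1 a.pdf
# c2 d2 b.pdf
```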
@@ -199,7 +203,6 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg

     if not conv.reference:
         conv.reference = []
-    conv.message.append({"role": "assistant", "content": "", "id": message_id})
     conv.reference.append({"chunks": [], "doc_aggs": []})

     if stream:
@@ -18,6 +18,7 @@ import binascii
 import os
 import json
 import re
+from collections import defaultdict
 from copy import deepcopy
 from timeit import default_timer as timer
 import datetime
@@ -108,6 +109,32 @@ def llm_id2llm_type(llm_id):
             return llm["model_type"].split(",")[-1]


+def kb_prompt(kbinfos, max_tokens):
+    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+    used_token_count = 0
+    chunks_num = 0
+    for i, c in enumerate(knowledges):
+        used_token_count += num_tokens_from_string(c)
+        chunks_num += 1
+        if max_tokens * 0.97 < used_token_count:
+            knowledges = knowledges[:i]
+            break
+
+    doc2chunks = defaultdict(list)
+    for i, ck in enumerate(kbinfos["chunks"]):
+        if i >= chunks_num:
+            break
+        doc2chunks[ck["docnm_kwd"]].append(ck["content_with_weight"])
+
+    knowledges = []
+    for nm, chunks in doc2chunks.items():
+        txt = f"Document: {nm} \nContains the following relevant fragments:\n"
+        for i, chunk in enumerate(chunks, 1):
+            txt += f"{i}. {chunk}\n"
+        knowledges.append(txt)
+    return knowledges
+
+
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
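`kb_prompt` centralizes what `chat` and `ask` used to duplicate: trim retrieved chunks to roughly 97% of the model's token budget, group the survivors by document name, and emit one numbered knowledge block per document. A sketch of what it produces, assuming the `kb_prompt` defined in the hunk above is in scope and using a stand-in tokenizer (the real `num_tokens_from_string` comes from the rag utilities):

```python
from collections import defaultdict

def num_tokens_from_string(s):
    # Stand-in tokenizer for this sketch: roughly one token per word.
    return len(s.split())

# Made-up retrieval results in the shape kb_prompt expects.
kbinfos = {"chunks": [
    {"docnm_kwd": "manual.pdf", "content_with_weight": "Install with pip."},
    {"docnm_kwd": "manual.pdf", "content_with_weight": "Run the server."},
    {"docnm_kwd": "faq.md", "content_with_weight": "Use Python 3.10 or newer."},
]}

for block in kb_prompt(kbinfos, max_tokens=8192):
    print(block)
# Document: manual.pdf
# Contains the following relevant fragments:
# 1. Install with pip.
# 2. Run the server.
#
# Document: faq.md
# Contains the following relevant fragments:
# 1. Use Python 3.10 or newer.
```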
@@ -195,32 +222,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                                         dialog.vector_similarity_weight,
                                         doc_ids=attachments,
                                         top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
-
-        # Group chunks by document ID
-        doc_chunks = {}
-        for ck in kbinfos["chunks"]:
-            doc_id = ck["doc_id"]
-            if doc_id not in doc_chunks:
-                doc_chunks[doc_id] = []
-            doc_chunks[doc_id].append(ck["content_with_weight"])
-
-        # Create knowledges list with grouped chunks
-        knowledges = []
-        for doc_id, chunks in doc_chunks.items():
-            # Find the corresponding document name
-            doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
-
-            # Create a header for the document
-            doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
-
-            # Add numbered fragments
-            for i, chunk in enumerate(chunks, 1):
-                doc_knowledge += f"{i}. {chunk}\n"
-
-            knowledges.append(doc_knowledge)
-
-
-
+        knowledges = kb_prompt(kbinfos, max_tokens)
         logging.debug(
             "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
         retrieval_tm = timer()
@@ -603,7 +605,6 @@ def tts(tts_mdl, text):

 def ask(question, kb_ids, tenant_id):
     kbs = KnowledgebaseService.get_by_ids(kb_ids)
-    tenant_ids = [kb.tenant_id for kb in kbs]
     embd_nms = list(set([kb.embd_id for kb in kbs]))

     is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
@@ -612,45 +613,9 @@ def ask(question, kb_ids, tenant_id):
     embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
     max_tokens = chat_mdl.max_length
-
+    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
     kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
-    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-
-    used_token_count = 0
-    chunks_num = 0
-    for i, c in enumerate(knowledges):
-        used_token_count += num_tokens_from_string(c)
-        if max_tokens * 0.97 < used_token_count:
-            knowledges = knowledges[:i]
-            chunks_num = chunks_num + 1
-            break
-
-    # Group chunks by document ID
-    doc_chunks = {}
-    counter_chunks = 0
-    for ck in kbinfos["chunks"]:
-        if counter_chunks < chunks_num:
-            counter_chunks = counter_chunks + 1
-            doc_id = ck["doc_id"]
-            if doc_id not in doc_chunks:
-                doc_chunks[doc_id] = []
-            doc_chunks[doc_id].append(ck["content_with_weight"])
-
-    # Create knowledges list with grouped chunks
-    knowledges = []
-    for doc_id, chunks in doc_chunks.items():
-        # Find the corresponding document name
-        doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
-
-        # Create a header for the document
-        doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
-
-        # Add numbered fragments
-        for i, chunk in enumerate(chunks, 1):
-            doc_knowledge += f"{i}. {chunk}\n"
-
-        knowledges.append(doc_knowledge)
-
+    knowledges = kb_prompt(kbinfos, max_tokens)
     prompt = """
 Role: You're a smart assistant. Your name is Miss R.
 Task: Summarize the information from knowledge bases and answer user's question.
@@ -660,25 +625,25 @@ def ask(question, kb_ids, tenant_id):
 - Answer with markdown format text.
 - Answer in language of user's question.
 - DO NOT make things up, especially for numbers.


 ### Information from knowledge bases
 %s


 The above is information from knowledge bases.

-    """%"\n".join(knowledges)
-
+    """ % "\n".join(knowledges)
     msg = [{"role": "user", "content": question}]

     def decorate_answer(answer):
         nonlocal knowledges, kbinfos, prompt
         answer, idx = retr.insert_citations(answer,
-                                            [ck["content_ltks"]
-                                             for ck in kbinfos["chunks"]],
-                                            [ck["vector"]
-                                             for ck in kbinfos["chunks"]],
-                                            embd_mdl,
-                                            tkweight=0.7,
-                                            vtweight=0.3)
+                                             [ck["content_ltks"]
+                                              for ck in kbinfos["chunks"]],
+                                             [ck["vector"]
+                                              for ck in kbinfos["chunks"]],
+                                             embd_mdl,
+                                             tkweight=0.7,
+                                             vtweight=0.3)
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
         recall_docs = [
             d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -691,7 +656,7 @@ def ask(question, kb_ids, tenant_id):
             del c["vector"]

         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
-            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
         return {"answer": answer, "reference": refs}

     answer = ""
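For reference, the corrected `""" % "\n".join(knowledges)` line performs ordinary %-formatting: the `%s` slot in the prompt template receives the knowledge blocks joined by newlines. A tiny sketch with an illustrative knowledge block (the template text comes from the diff; the sample block is made up):

```python
knowledges = [
    "Document: manual.pdf \nContains the following relevant fragments:\n1. Install with pip.\n",
]
prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.

### Information from knowledge bases
%s

The above is information from knowledge bases.
""" % "\n".join(knowledges)
print(prompt)
```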