### What problem does this PR solve?

#4367

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-01-09 17:07:21 +08:00
committed by GitHub
parent f892d7d426
commit c5da3cdd97
30 changed files with 736 additions and 202 deletions

View File

@ -25,7 +25,7 @@ from api.db import FileType, LLMType, ParserType, FileSource
from api.db.db_models import APIToken, Task, File
from api.db.services import duplicate_name
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.dialog_service import DialogService, chat, keyword_extraction
from api.db.services.dialog_service import DialogService, chat, keyword_extraction, label_question
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
@ -840,7 +840,8 @@ def retrieval():
question += keyword_extraction(chat_mdl, question)
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
doc_ids, rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs))
for c in ranks["chunks"]:
c.pop("vector", None)
return get_json_result(data=ranks)

View File

@ -19,9 +19,10 @@ import json
from flask import request
from flask_login import login_required, current_user
from api.db.services.dialog_service import keyword_extraction
from api.db.services.dialog_service import keyword_extraction, label_question
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -124,10 +125,14 @@ def set():
"content_with_weight": req["content_with_weight"]}
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
d["question_kwd"] = req["question_kwd"]
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
if req.get("important_kwd"):
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
if req.get("question_kwd"):
d["question_kwd"] = req["question_kwd"]
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
if req.get("tag_kwd"):
d["tag_kwd"] = req["tag_kwd"]
if "available_int" in req:
d["available_int"] = req["available_int"]
@ -220,7 +225,7 @@ def create():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
d["kb_id"] = doc.kb_id
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
d["doc_id"] = doc.id
@ -233,7 +238,7 @@ def create():
if not e:
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank:
d["pagerank_fea"] = kb.pagerank
d[PAGERANK_FLD] = kb.pagerank
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
@ -294,12 +299,16 @@ def retrieval_test():
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
rank_feature=labels
)
for c in ranks["chunks"]:
c.pop("vector", None)
ranks["labels"] = labels
return get_json_result(data=ranks)
except Exception as e:

View File

@ -25,7 +25,7 @@ from flask import request, Response
from flask_login import login_required, current_user
from api.db import LLMType
from api.db.services.dialog_service import DialogService, chat, ask
from api.db.services.dialog_service import DialogService, chat, ask, label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
from api import settings
@ -379,8 +379,11 @@ def mindmap():
embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False)
question = req["question"]
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False,
rank_feature=label_question(question, [kb])
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
if "error" in mind_map:

View File

@ -30,6 +30,7 @@ from api.utils.api_utils import get_json_result
from api import settings
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD
@manager.route('/create', methods=['post']) # noqa: F821
@ -104,11 +105,11 @@ def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires pagerank_fea be non-zero!
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
e, kb = KnowledgebaseService.get_by_id(kb.id)
@ -150,12 +151,14 @@ def list_kbs():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 150))
parser_id = request.args.get("parser_id")
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
kbs, total = KnowledgebaseService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords)
[m["tenant_id"] for m in tenants], current_user.id, page_number,
items_per_page, orderby, desc, keywords, parser_id)
return get_json_result(data={"kbs": kbs, "total": total})
except Exception as e:
return server_error_response(e)
@ -199,3 +202,72 @@ def rm():
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/<kb_id>/tags', methods=['GET']) # noqa: F821
@login_required
def list_tags(kb_id):
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
tags = settings.retrievaler.all_tags(current_user.id, [kb_id])
return get_json_result(data=tags)
@manager.route('/tags', methods=['GET']) # noqa: F821
@login_required
def list_tags_from_kbs():
kb_ids = request.args.get("kb_ids", "").split(",")
for kb_id in kb_ids:
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
tags = settings.retrievaler.all_tags(current_user.id, kb_ids)
return get_json_result(data=tags)
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
@login_required
def rm_tags(kb_id):
req = request.json
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
e, kb = KnowledgebaseService.get_by_id(kb_id)
for t in req["tags"]:
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
{"remove": {"tag_kwd": t}},
search.index_name(kb.tenant_id),
kb_id)
return get_json_result(data=True)
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
@login_required
def rename_tags(kb_id):
req = request.json
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
e, kb = KnowledgebaseService.get_by_id(kb_id)
settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
return get_json_result(data=True)

View File

@ -73,7 +73,8 @@ def create(tenant_id):
chunk_method:
type: string
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
"presentation", "picture", "one", "knowledge_graph", "email"]
"presentation", "picture", "one", "knowledge_graph", "email", "tag"
]
description: Chunking method.
parser_config:
type: object
@ -108,6 +109,7 @@ def create(tenant_id):
"one",
"knowledge_graph",
"email",
"tag"
]
check_validation = valid(
permission,
@ -302,7 +304,8 @@ def update(tenant_id, dataset_id):
chunk_method:
type: string
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
"presentation", "picture", "one", "knowledge_graph", "email"]
"presentation", "picture", "one", "knowledge_graph", "email", "tag"
]
description: Updated chunking method.
parser_config:
type: object
@ -339,6 +342,7 @@ def update(tenant_id, dataset_id):
"one",
"knowledge_graph",
"email",
"tag"
]
check_validation = valid(
permission,

View File

@ -16,6 +16,7 @@
from flask import request, jsonify
from api.db import LLMType, ParserType
from api.db.services.dialog_service import label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
@ -54,7 +55,8 @@ def retrieval(tenant_id):
page_size=top,
similarity_threshold=similarity_threshold,
vector_similarity_weight=0.3,
top=top
top=top,
rank_feature=label_question(question, [kb])
)
records = []
for c in ranks["chunks"]:

View File

@ -16,7 +16,7 @@
import pathlib
import datetime
from api.db.services.dialog_service import keyword_extraction
from api.db.services.dialog_service import keyword_extraction, label_question
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import rag_tokenizer
from api.db import LLMType, ParserType
@ -276,6 +276,7 @@ def update_doc(tenant_id, dataset_id, document_id):
"one",
"knowledge_graph",
"email",
"tag"
}
if req.get("chunk_method") not in valid_chunk_method:
return get_error_data_result(
@ -1355,6 +1356,7 @@ def retrieval_test(tenant_id):
doc_ids,
rerank_mdl=rerank_mdl,
highlight=highlight,
rank_feature=label_question(question, kbs)
)
for c in ranks["chunks"]:
c.pop("vector", None)