### What problem does this PR solve?

#4367

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-01-09 17:07:21 +08:00
committed by GitHub
parent f892d7d426
commit c5da3cdd97
30 changed files with 736 additions and 202 deletions

View File

@ -29,8 +29,10 @@ from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api import settings
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.settings import TAG_FLD
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
@ -135,6 +137,29 @@ def kb_prompt(kbinfos, max_tokens):
return knowledges
def label_question(question, kbs):
tags = None
tag_kb_ids = []
for kb in kbs:
if kb.parser_config.get("tag_kb_ids"):
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
if tag_kb_ids:
all_tags = get_tags_from_cache(tag_kb_ids)
if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
set_tags_to_cache(all_tags, tag_kb_ids)
else:
all_tags = json.loads(all_tags)
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
tags = settings.retrievaler.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,
kb.parser_config.get("topn_tags", 3)
)
return tags
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
@ -236,11 +261,14 @@ def chat(dialog, messages, stream=True, **kwargs):
generate_keyword_ts = timer()
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
retrieval_ts = timer()
@ -650,7 +678,10 @@ def ask(question, kb_ids, tenant_id):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
1, 12, 0.1, 0.3, aggs=False,
rank_feature=label_question(question, kbs)
)
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """
Role: You're a smart assistant. Your name is Miss R.
@ -700,3 +731,56 @@ def ask(question, kb_ids, tenant_id):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
prompt = f"""
Role: You're a text analyzer.
Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.
Steps::
- Comprehend the tag/label set.
- Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
- Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
Requirements
- The tags MUST be from the tag set.
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
- The relevance score must be range from 1 to 10.
- Keywords ONLY in output.
# TAG SET
{", ".join(all_tags)}
"""
for i, ex in enumerate(examples):
prompt += """
# Examples {}
### Text Content
{}
Output:
{}
""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
prompt += f"""
# Real Data
### Text Content
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
if kwd.find("**ERROR**") >= 0:
raise Exception(kwd)
kwd = re.sub(r".*?\{", "{", kwd)
return json.loads(kwd)