### What problem does this PR solve?

#4367

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-01-09 17:07:21 +08:00
committed by GitHub
parent f892d7d426
commit c5da3cdd97
30 changed files with 736 additions and 202 deletions

View File

@ -89,6 +89,7 @@ class ParserType(StrEnum):
AUDIO = "audio"
EMAIL = "email"
KG = "knowledge_graph"
TAG = "tag"
class FileSource(StrEnum):

View File

@ -133,7 +133,7 @@ def init_llm_factory():
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
TenantService.filter_update([1 == 1], {
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"})
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
## insert openai two embedding models to the current openai user.
# print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
@ -153,14 +153,7 @@ def init_llm_factory():
break
for kb_id in KnowledgebaseService.get_all_ids():
KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
"""
drop table llm;
drop table llm_factories;
update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
alter table knowledgebase modify avatar longtext;
alter table user modify avatar longtext;
alter table dialog modify icon longtext;
"""
def add_graph_templates():

View File

@ -29,8 +29,10 @@ from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api import settings
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.settings import TAG_FLD
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
@ -135,6 +137,29 @@ def kb_prompt(kbinfos, max_tokens):
return knowledges
def label_question(question, kbs):
tags = None
tag_kb_ids = []
for kb in kbs:
if kb.parser_config.get("tag_kb_ids"):
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
if tag_kb_ids:
all_tags = get_tags_from_cache(tag_kb_ids)
if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
set_tags_to_cache(all_tags, tag_kb_ids)
else:
all_tags = json.loads(all_tags)
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
tags = settings.retrievaler.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,
kb.parser_config.get("topn_tags", 3)
)
return tags
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
@ -236,11 +261,14 @@ def chat(dialog, messages, stream=True, **kwargs):
generate_keyword_ts = timer()
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
retrieval_ts = timer()
@ -650,7 +678,10 @@ def ask(question, kb_ids, tenant_id):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
1, 12, 0.1, 0.3, aggs=False,
rank_feature=label_question(question, kbs)
)
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """
Role: You're a smart assistant. Your name is Miss R.
@ -700,3 +731,56 @@ def ask(question, kb_ids, tenant_id):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
prompt = f"""
Role: You're a text analyzer.
Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.
Steps::
- Comprehend the tag/label set.
- Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
- Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
Requirements
- The tags MUST be from the tag set.
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
- The relevance score must be range from 1 to 10.
- Keywords ONLY in output.
# TAG SET
{", ".join(all_tags)}
"""
for i, ex in enumerate(examples):
prompt += """
# Examples {}
### Text Content
{}
Output:
{}
""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
prompt += f"""
# Real Data
### Text Content
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
if kwd.find("**ERROR**") >= 0:
raise Exception(kwd)
kwd = re.sub(r".*?\{", "{", kwd)
return json.loads(kwd)

View File

@ -43,10 +43,7 @@ class File2DocumentService(CommonService):
def insert(cls, obj):
if not cls.save(**obj):
raise RuntimeError("Database error (File)!")
e, obj = cls.get_by_id(obj["id"])
if not e:
raise RuntimeError("Database error (File retrieval)!")
return obj
return File2Document(**obj)
@classmethod
@DB.connection_context()
@ -63,9 +60,8 @@ class File2DocumentService(CommonService):
def update_by_file_id(cls, file_id, obj):
obj["update_time"] = current_timestamp()
obj["update_date"] = datetime_format(datetime.now())
# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
e, obj = cls.get_by_id(cls.model.id)
return obj
cls.model.update(obj).where(cls.model.id == file_id).execute()
return File2Document(**obj)
@classmethod
@DB.connection_context()

View File

@ -251,10 +251,7 @@ class FileService(CommonService):
def insert(cls, file):
if not cls.save(**file):
raise RuntimeError("Database error (File)!")
e, file = cls.get_by_id(file["id"])
if not e:
raise RuntimeError("Database error (File retrieval)!")
return file
return File(**file)
@classmethod
@DB.connection_context()

View File

@ -35,7 +35,10 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc, keywords):
page_number, items_per_page,
orderby, desc, keywords,
parser_id=None
):
fields = [
cls.model.id,
cls.model.avatar,
@ -67,6 +70,8 @@ class KnowledgebaseService(CommonService):
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if parser_id:
kbs = kbs.where(cls.model.parser_id == parser_id)
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:

View File

@ -69,6 +69,7 @@ class TaskService(CommonService):
Knowledgebase.language,
Knowledgebase.embd_id,
Knowledgebase.pagerank,
Knowledgebase.parser_config.alias("kb_parser_config"),
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,