mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Tagging (#4426)
### What problem does this PR solve? #4367 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -25,7 +25,7 @@ from api.db import FileType, LLMType, ParserType, FileSource
|
||||
from api.db.db_models import APIToken, Task, File
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.dialog_service import DialogService, chat, keyword_extraction
|
||||
from api.db.services.dialog_service import DialogService, chat, keyword_extraction, label_question
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
@ -840,7 +840,8 @@ def retrieval():
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl)
|
||||
doc_ids, rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(question, kbs))
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
@ -19,9 +19,10 @@ import json
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.services.dialog_service import keyword_extraction
|
||||
from api.db.services.dialog_service import keyword_extraction, label_question
|
||||
from rag.app.qa import rmPrefix, beAdoc
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -124,10 +125,14 @@ def set():
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
|
||||
d["question_kwd"] = req["question_kwd"]
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
|
||||
if req.get("important_kwd"):
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
|
||||
if req.get("question_kwd"):
|
||||
d["question_kwd"] = req["question_kwd"]
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
|
||||
if req.get("tag_kwd"):
|
||||
d["tag_kwd"] = req["tag_kwd"]
|
||||
if "available_int" in req:
|
||||
d["available_int"] = req["available_int"]
|
||||
|
||||
@ -220,7 +225,7 @@ def create():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
d["kb_id"] = doc.kb_id
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
d["doc_id"] = doc.id
|
||||
@ -233,7 +238,7 @@ def create():
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank:
|
||||
d["pagerank_fea"] = kb.pagerank
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
@ -294,12 +299,16 @@ def retrieval_test():
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
||||
ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
|
||||
rank_feature=labels
|
||||
)
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
ranks["labels"] = labels
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
|
||||
@ -25,7 +25,7 @@ from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.dialog_service import DialogService, chat, ask
|
||||
from api.db.services.dialog_service import DialogService, chat, ask, label_question
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
||||
from api import settings
|
||||
@ -379,8 +379,11 @@ def mindmap():
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
||||
ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
||||
0.3, 0.3, aggs=False)
|
||||
question = req["question"]
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
||||
0.3, 0.3, aggs=False,
|
||||
rank_feature=label_question(question, [kb])
|
||||
)
|
||||
mindmap = MindMapExtractor(chat_mdl)
|
||||
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
||||
if "error" in mind_map:
|
||||
|
||||
@ -30,6 +30,7 @@ from api.utils.api_utils import get_json_result
|
||||
from api import settings
|
||||
from rag.nlp import search
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from rag.settings import PAGERANK_FLD
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@ -104,11 +105,11 @@ def update():
|
||||
|
||||
if kb.pagerank != req.get("pagerank", 0):
|
||||
if req.get("pagerank", 0) > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
else:
|
||||
# Elasticsearch requires pagerank_fea be non-zero!
|
||||
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
@ -150,12 +151,14 @@ def list_kbs():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 150))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
kbs, total = KnowledgebaseService.get_by_tenant_ids(
|
||||
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords)
|
||||
[m["tenant_id"] for m in tenants], current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords, parser_id)
|
||||
return get_json_result(data={"kbs": kbs, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -199,3 +202,72 @@ def rm():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/<kb_id>/tags', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_tags(kb_id):
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
tags = settings.retrievaler.all_tags(current_user.id, [kb_id])
|
||||
return get_json_result(data=tags)
|
||||
|
||||
|
||||
@manager.route('/tags', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_tags_from_kbs():
|
||||
kb_ids = request.args.get("kb_ids", "").split(",")
|
||||
for kb_id in kb_ids:
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
tags = settings.retrievaler.all_tags(current_user.id, kb_ids)
|
||||
return get_json_result(data=tags)
|
||||
|
||||
|
||||
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rm_tags(kb_id):
|
||||
req = request.json
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
|
||||
for t in req["tags"]:
|
||||
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
|
||||
{"remove": {"tag_kwd": t}},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rename_tags(kb_id):
|
||||
req = request.json
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
|
||||
settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
|
||||
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb_id)
|
||||
return get_json_result(data=True)
|
||||
@ -73,7 +73,8 @@ def create(tenant_id):
|
||||
chunk_method:
|
||||
type: string
|
||||
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
||||
"presentation", "picture", "one", "knowledge_graph", "email"]
|
||||
"presentation", "picture", "one", "knowledge_graph", "email", "tag"
|
||||
]
|
||||
description: Chunking method.
|
||||
parser_config:
|
||||
type: object
|
||||
@ -108,6 +109,7 @@ def create(tenant_id):
|
||||
"one",
|
||||
"knowledge_graph",
|
||||
"email",
|
||||
"tag"
|
||||
]
|
||||
check_validation = valid(
|
||||
permission,
|
||||
@ -302,7 +304,8 @@ def update(tenant_id, dataset_id):
|
||||
chunk_method:
|
||||
type: string
|
||||
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
||||
"presentation", "picture", "one", "knowledge_graph", "email"]
|
||||
"presentation", "picture", "one", "knowledge_graph", "email", "tag"
|
||||
]
|
||||
description: Updated chunking method.
|
||||
parser_config:
|
||||
type: object
|
||||
@ -339,6 +342,7 @@ def update(tenant_id, dataset_id):
|
||||
"one",
|
||||
"knowledge_graph",
|
||||
"email",
|
||||
"tag"
|
||||
]
|
||||
check_validation = valid(
|
||||
permission,
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
from flask import request, jsonify
|
||||
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.services.dialog_service import label_question
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
@ -54,7 +55,8 @@ def retrieval(tenant_id):
|
||||
page_size=top,
|
||||
similarity_threshold=similarity_threshold,
|
||||
vector_similarity_weight=0.3,
|
||||
top=top
|
||||
top=top,
|
||||
rank_feature=label_question(question, [kb])
|
||||
)
|
||||
records = []
|
||||
for c in ranks["chunks"]:
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
import pathlib
|
||||
import datetime
|
||||
|
||||
from api.db.services.dialog_service import keyword_extraction
|
||||
from api.db.services.dialog_service import keyword_extraction, label_question
|
||||
from rag.app.qa import rmPrefix, beAdoc
|
||||
from rag.nlp import rag_tokenizer
|
||||
from api.db import LLMType, ParserType
|
||||
@ -276,6 +276,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
"one",
|
||||
"knowledge_graph",
|
||||
"email",
|
||||
"tag"
|
||||
}
|
||||
if req.get("chunk_method") not in valid_chunk_method:
|
||||
return get_error_data_result(
|
||||
@ -1355,6 +1356,7 @@ def retrieval_test(tenant_id):
|
||||
doc_ids,
|
||||
rerank_mdl=rerank_mdl,
|
||||
highlight=highlight,
|
||||
rank_feature=label_question(question, kbs)
|
||||
)
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
|
||||
@ -89,6 +89,7 @@ class ParserType(StrEnum):
|
||||
AUDIO = "audio"
|
||||
EMAIL = "email"
|
||||
KG = "knowledge_graph"
|
||||
TAG = "tag"
|
||||
|
||||
|
||||
class FileSource(StrEnum):
|
||||
|
||||
@ -133,7 +133,7 @@ def init_llm_factory():
|
||||
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
||||
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
|
||||
TenantService.filter_update([1 == 1], {
|
||||
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"})
|
||||
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
|
||||
## insert openai two embedding models to the current openai user.
|
||||
# print("Start to insert 2 OpenAI embedding models...")
|
||||
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
|
||||
@ -153,14 +153,7 @@ def init_llm_factory():
|
||||
break
|
||||
for kb_id in KnowledgebaseService.get_all_ids():
|
||||
KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
|
||||
"""
|
||||
drop table llm;
|
||||
drop table llm_factories;
|
||||
update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
|
||||
alter table knowledgebase modify avatar longtext;
|
||||
alter table user modify avatar longtext;
|
||||
alter table dialog modify icon longtext;
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def add_graph_templates():
|
||||
|
||||
@ -29,8 +29,10 @@ from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
||||
from api import settings
|
||||
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.nlp.search import index_name
|
||||
from rag.settings import TAG_FLD
|
||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
@ -135,6 +137,29 @@ def kb_prompt(kbinfos, max_tokens):
|
||||
return knowledges
|
||||
|
||||
|
||||
def label_question(question, kbs):
|
||||
tags = None
|
||||
tag_kb_ids = []
|
||||
for kb in kbs:
|
||||
if kb.parser_config.get("tag_kb_ids"):
|
||||
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
|
||||
if tag_kb_ids:
|
||||
all_tags = get_tags_from_cache(tag_kb_ids)
|
||||
if not all_tags:
|
||||
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
|
||||
set_tags_to_cache(all_tags, tag_kb_ids)
|
||||
else:
|
||||
all_tags = json.loads(all_tags)
|
||||
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
|
||||
tags = settings.retrievaler.tag_query(question,
|
||||
list(set([kb.tenant_id for kb in tag_kbs])),
|
||||
tag_kb_ids,
|
||||
all_tags,
|
||||
kb.parser_config.get("topn_tags", 3)
|
||||
)
|
||||
return tags
|
||||
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
|
||||
@ -236,11 +261,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
generate_keyword_ts = timer()
|
||||
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=attachments,
|
||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
|
||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs)
|
||||
)
|
||||
|
||||
retrieval_ts = timer()
|
||||
|
||||
@ -650,7 +678,10 @@ def ask(question, kb_ids, tenant_id):
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
max_tokens = chat_mdl.max_length
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
|
||||
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
|
||||
1, 12, 0.1, 0.3, aggs=False,
|
||||
rank_feature=label_question(question, kbs)
|
||||
)
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
prompt = """
|
||||
Role: You're a smart assistant. Your name is Miss R.
|
||||
@ -700,3 +731,56 @@ def ask(question, kb_ids, tenant_id):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
|
||||
|
||||
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
||||
prompt = f"""
|
||||
Role: You're a text analyzer.
|
||||
|
||||
Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.
|
||||
|
||||
Steps::
|
||||
- Comprehend the tag/label set.
|
||||
- Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
|
||||
- Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
|
||||
|
||||
Requirements
|
||||
- The tags MUST be from the tag set.
|
||||
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
|
||||
- The relevance score must be range from 1 to 10.
|
||||
- Keywords ONLY in output.
|
||||
|
||||
# TAG SET
|
||||
{", ".join(all_tags)}
|
||||
|
||||
"""
|
||||
for i, ex in enumerate(examples):
|
||||
prompt += """
|
||||
# Examples {}
|
||||
### Text Content
|
||||
{}
|
||||
|
||||
Output:
|
||||
{}
|
||||
|
||||
""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
|
||||
|
||||
prompt += f"""
|
||||
# Real Data
|
||||
### Text Content
|
||||
{content}
|
||||
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
raise Exception(kwd)
|
||||
|
||||
kwd = re.sub(r".*?\{", "{", kwd)
|
||||
return json.loads(kwd)
|
||||
@ -43,10 +43,7 @@ class File2DocumentService(CommonService):
|
||||
def insert(cls, obj):
|
||||
if not cls.save(**obj):
|
||||
raise RuntimeError("Database error (File)!")
|
||||
e, obj = cls.get_by_id(obj["id"])
|
||||
if not e:
|
||||
raise RuntimeError("Database error (File retrieval)!")
|
||||
return obj
|
||||
return File2Document(**obj)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -63,9 +60,8 @@ class File2DocumentService(CommonService):
|
||||
def update_by_file_id(cls, file_id, obj):
|
||||
obj["update_time"] = current_timestamp()
|
||||
obj["update_date"] = datetime_format(datetime.now())
|
||||
# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
|
||||
e, obj = cls.get_by_id(cls.model.id)
|
||||
return obj
|
||||
cls.model.update(obj).where(cls.model.id == file_id).execute()
|
||||
return File2Document(**obj)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
|
||||
@ -251,10 +251,7 @@ class FileService(CommonService):
|
||||
def insert(cls, file):
|
||||
if not cls.save(**file):
|
||||
raise RuntimeError("Database error (File)!")
|
||||
e, file = cls.get_by_id(file["id"])
|
||||
if not e:
|
||||
raise RuntimeError("Database error (File retrieval)!")
|
||||
return file
|
||||
return File(**file)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
|
||||
@ -35,7 +35,10 @@ class KnowledgebaseService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
page_number, items_per_page, orderby, desc, keywords):
|
||||
page_number, items_per_page,
|
||||
orderby, desc, keywords,
|
||||
parser_id=None
|
||||
):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
@ -67,6 +70,8 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.tenant_id == user_id))
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if parser_id:
|
||||
kbs = kbs.where(cls.model.parser_id == parser_id)
|
||||
if desc:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
|
||||
@ -69,6 +69,7 @@ class TaskService(CommonService):
|
||||
Knowledgebase.language,
|
||||
Knowledgebase.embd_id,
|
||||
Knowledgebase.pagerank,
|
||||
Knowledgebase.parser_config.alias("kb_parser_config"),
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
Tenant.llm_id,
|
||||
|
||||
@ -140,7 +140,7 @@ def init_settings():
|
||||
API_KEY = LLM.get("api_key", "")
|
||||
PARSERS = LLM.get(
|
||||
"parsers",
|
||||
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
|
||||
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag")
|
||||
|
||||
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
||||
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
||||
|
||||
@ -173,6 +173,7 @@ def validate_request(*args, **kwargs):
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def not_allowed_parameters(*params):
|
||||
def decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
@ -182,7 +183,9 @@ def not_allowed_parameters(*params):
|
||||
return get_json_result(
|
||||
code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@ -207,6 +210,7 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None)
|
||||
response = {"code": code, "message": message, "data": data}
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
def apikey_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
@ -282,17 +286,18 @@ def construct_error_response(e):
|
||||
def token_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
authorization_str=flask_request.headers.get('Authorization')
|
||||
authorization_str = flask_request.headers.get('Authorization')
|
||||
if not authorization_str:
|
||||
return get_json_result(data=False,message="`Authorization` can't be empty")
|
||||
authorization_list=authorization_str.split()
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_list = authorization_str.split()
|
||||
if len(authorization_list) < 2:
|
||||
return get_json_result(data=False,message="Please check your authorization format.")
|
||||
return get_json_result(data=False, message="Please check your authorization format.")
|
||||
token = authorization_list[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!', code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
data=False, message='Authentication error: API key is invalid!',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
kwargs['tenant_id'] = objs[0].tenant_id
|
||||
return func(*args, **kwargs)
|
||||
@ -330,35 +335,41 @@ def generate_confirmation_token(tenent_id):
|
||||
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
|
||||
|
||||
|
||||
def valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method):
|
||||
if valid_parameter(permission,valid_permission):
|
||||
return valid_parameter(permission,valid_permission)
|
||||
if valid_parameter(language,valid_language):
|
||||
return valid_parameter(language,valid_language)
|
||||
if valid_parameter(chunk_method,valid_chunk_method):
|
||||
return valid_parameter(chunk_method,valid_chunk_method)
|
||||
def valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method):
|
||||
if valid_parameter(permission, valid_permission):
|
||||
return valid_parameter(permission, valid_permission)
|
||||
if valid_parameter(language, valid_language):
|
||||
return valid_parameter(language, valid_language)
|
||||
if valid_parameter(chunk_method, valid_chunk_method):
|
||||
return valid_parameter(chunk_method, valid_chunk_method)
|
||||
|
||||
def valid_parameter(parameter,valid_values):
|
||||
|
||||
def valid_parameter(parameter, valid_values):
|
||||
if parameter and parameter not in valid_values:
|
||||
return get_error_data_result(f"'{parameter}' is not in {valid_values}")
|
||||
return get_error_data_result(f"'{parameter}' is not in {valid_values}")
|
||||
|
||||
def get_parser_config(chunk_method,parser_config):
|
||||
|
||||
def get_parser_config(chunk_method, parser_config):
|
||||
if parser_config:
|
||||
return parser_config
|
||||
if not chunk_method:
|
||||
chunk_method = "naive"
|
||||
key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"use_raptor": False}},
|
||||
"qa":{"raptor":{"use_raptor":False}},
|
||||
"resume":None,
|
||||
"manual":{"raptor":{"use_raptor":False}},
|
||||
"table":None,
|
||||
"paper":{"raptor":{"use_raptor":False}},
|
||||
"book":{"raptor":{"use_raptor":False}},
|
||||
"laws":{"raptor":{"use_raptor":False}},
|
||||
"presentation":{"raptor":{"use_raptor":False}},
|
||||
"one":None,
|
||||
"knowledge_graph":{"chunk_token_num":8192,"delimiter":"\\n!?;。;!?","entity_types":["organization","person","location","event","time"]},
|
||||
"email":None,
|
||||
"picture":None}
|
||||
parser_config=key_mapping[chunk_method]
|
||||
return parser_config
|
||||
key_mapping = {
|
||||
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": True,
|
||||
"raptor": {"use_raptor": False}},
|
||||
"qa": {"raptor": {"use_raptor": False}},
|
||||
"tag": None,
|
||||
"resume": None,
|
||||
"manual": {"raptor": {"use_raptor": False}},
|
||||
"table": None,
|
||||
"paper": {"raptor": {"use_raptor": False}},
|
||||
"book": {"raptor": {"use_raptor": False}},
|
||||
"laws": {"raptor": {"use_raptor": False}},
|
||||
"presentation": {"raptor": {"use_raptor": False}},
|
||||
"one": None,
|
||||
"knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
|
||||
"entity_types": ["organization", "person", "location", "event", "time"]},
|
||||
"email": None,
|
||||
"picture": None}
|
||||
parser_config = key_mapping[chunk_method]
|
||||
return parser_config
|
||||
|
||||
Reference in New Issue
Block a user