mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add resume parser and fix bugs (#59)
* Update .gitignore * Update .gitignore * Add resume parser and fix bugs
This commit is contained in:
@ -47,17 +47,20 @@ def list():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
query = {
|
||||
"doc_ids": [doc_id], "page": page, "size": size, "question": question
|
||||
}
|
||||
if "available_int" in req:
|
||||
query["available_int"] = int(req["available_int"])
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id))
|
||||
res = {"total": sres.total, "chunks": []}
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id].get("content_with_weight", ""),
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||
@ -110,7 +113,7 @@ def get():
|
||||
"important_kwd")
|
||||
def set():
|
||||
req = request.json
|
||||
d = {"id": req["chunk_id"]}
|
||||
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = huqie.qie(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
@ -181,11 +184,12 @@ def create():
|
||||
md5 = hashlib.md5()
|
||||
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
|
||||
chunck_id = md5.hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
|
||||
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]}
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
d["important_kwd"] = req.get("important_kwd", [])
|
||||
d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
|
||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
|
||||
@ -13,16 +13,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from api.db.services.dialog_service import DialogService, ConversationService
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle
|
||||
from api.settings import access_logger
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.llm import ChatModel
|
||||
from rag.nlp import retrievaler
|
||||
from rag.nlp.search import index_name
|
||||
from rag.utils import num_tokens_from_string, encoder
|
||||
|
||||
|
||||
@ -163,6 +168,17 @@ def chat(dialog, messages, **kwargs):
|
||||
if not llm:
|
||||
raise LookupError("LLM(%s) not found"%dialog.llm_id)
|
||||
llm = llm[0]
|
||||
question = messages[-1]["content"]
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
## try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
markdown_tbl,chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
|
||||
if markdown_tbl:
|
||||
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge":continue
|
||||
@ -170,9 +186,6 @@ def chat(dialog, messages, **kwargs):
|
||||
if p["key"] not in kwargs:
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
|
||||
|
||||
question = messages[-1]["content"]
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight, top=1024, aggs=False)
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
@ -196,4 +209,46 @@ def chat(dialog, messages, **kwargs):
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
for c in kbinfos["chunks"]:
|
||||
if c.get("vector"):del c["vector"]
|
||||
return {"answer": answer, "retrieval": kbinfos}
|
||||
return {"answer": answer, "retrieval": kbinfos}
|
||||
|
||||
|
||||
def use_sql(question,field_map, tenant_id, chat_mdl):
|
||||
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
|
||||
user_promt = """
|
||||
表名:{};
|
||||
数据库表字段说明如下:
|
||||
{}
|
||||
|
||||
问题:{}
|
||||
请写出SQL。
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
"\n".join([f"{k}: {v}" for k,v in field_map.items()]),
|
||||
question
|
||||
)
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
|
||||
sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
if sql[:len("select ")].lower() != "select ":
|
||||
return None, None
|
||||
if sql[:len("select *")].lower() != "select *":
|
||||
sql = "select doc_id,docnm_kwd," + sql[6:]
|
||||
|
||||
tbl = retrievaler.sql_retrieval(sql)
|
||||
if not tbl: return None, None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
|
||||
|
||||
clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
|
||||
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
|
||||
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
|
||||
if not docid_idx or not docnm_idx:
|
||||
access_logger.error("SQL missing field: " + sql)
|
||||
return "\n".join([clmns, line, "\n".join(rows)]), []
|
||||
|
||||
rows = "\n".join([r+f"##{ii}$$" for ii,r in enumerate(rows)])
|
||||
docid_idx = list(docid_idx)[0]
|
||||
docnm_idx = list(docnm_idx)[0]
|
||||
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
|
||||
|
||||
@ -21,9 +21,6 @@ import flask
|
||||
from elasticsearch_dsl import Q
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.db_models import Task
|
||||
from api.db.services.task_service import TaskService
|
||||
from rag.nlp import search
|
||||
from rag.utils import ELASTICSEARCH
|
||||
from api.db.services import duplicate_name
|
||||
@ -35,7 +32,7 @@ from api.db.services.document_service import DocumentService
|
||||
from api.settings import RetCode
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.utils.minio_conn import MINIO
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
|
||||
|
||||
@manager.route('/upload', methods=['POST'])
|
||||
@ -78,7 +75,8 @@ def upload():
|
||||
"type": filename_type(filename),
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob)
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob)
|
||||
})
|
||||
return get_json_result(data=doc.to_json())
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user