Fix: Document.update() now refreshes object data (#8068)

### What problem does this PR solve?

Fixes #8067. `Document.update()` previously only issued the PUT request, so the local `Document` object kept its stale attribute values after an update. The update endpoint now returns the refreshed document payload, and the SDK applies that payload to the object and returns `self`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
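
Below is a minimal usage sketch of the behavior this PR enables. The connection details, lookup calls, and field names are illustrative and may differ in your SDK version; the point is that `update()` now refreshes the object instead of leaving its attributes stale.

```python
from ragflow_sdk import RAGFlow

# Illustrative connection details -- replace with your own deployment.
rag = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://localhost:9380")

dataset = rag.list_datasets(name="demo")[0]          # assumes a dataset named "demo" exists
doc = dataset.list_documents(keywords="report")[0]   # assumes a matching document exists

# Before this fix, update() only sent the PUT request and returned None,
# so doc.chunk_method / doc.meta_fields still held the old values locally.
# It now applies the server's refreshed payload to the object and returns self.
doc = doc.update({"chunk_method": "naive", "meta_fields": {"author": "alice"}})
print(doc.chunk_method)  # reflects the value stored on the server
```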
Author: Liu An
Date: 2025-06-05 12:46:29 +08:00
Committed by: GitHub
Parent: 640fca7dc9
Commit: 8b7c424617
2 changed files with 123 additions and 166 deletions

Changed file 1 of 2:

@@ -13,38 +13,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import pathlib
 import datetime
-from rag.app.qa import rmPrefix, beAdoc
-from rag.nlp import rag_tokenizer
-from api.db import LLMType, ParserType
-from api.db.services.llm_service import TenantLLMService, LLMBundle
-from api import settings
-import xxhash
+import logging
+import pathlib
 import re
-from api.utils.api_utils import token_required
-from api.db.db_models import Task
-from api.db.services.task_service import TaskService, queue_tasks
-from api.utils.api_utils import server_error_response
-from api.utils.api_utils import get_result, get_error_data_result
 from io import BytesIO
+
+import xxhash
 from flask import request, send_file
-from api.db import FileSource, TaskStatus, FileType
-from api.db.db_models import File
+from peewee import OperationalError
+from pydantic import BaseModel, Field, validator
+
+from api import settings
+from api.db import FileSource, FileType, LLMType, ParserType, TaskStatus
+from api.db.db_models import File, Task
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.api_utils import construct_json_result, get_parser_config, check_duplicate_ids
-from rag.nlp import search
-from rag.prompts import keyword_extraction
+from api.db.services.llm_service import LLMBundle, TenantLLMService
+from api.db.services.task_service import TaskService, queue_tasks
+from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
+from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
+from rag.nlp import rag_tokenizer, search
+from rag.prompts import keyword_extraction
 from rag.utils import rmSpace
 from rag.utils.storage_factory import STORAGE_IMPL
-from pydantic import BaseModel, Field, validator
 
 MAXIMUM_OF_UPLOADING_FILES = 256
@@ -60,7 +56,7 @@ class Chunk(BaseModel):
     available: bool = True
     positions: list[list[int]] = Field(default_factory=list)
 
-    @validator('positions')
+    @validator("positions")
     def validate_positions(cls, value):
         for sublist in value:
             if len(sublist) != 5:
@@ -128,20 +124,14 @@ def upload(dataset_id, tenant_id):
             description: Processing status.
     """
     if "file" not in request.files:
-        return get_error_data_result(
-            message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
-        )
+        return get_error_data_result(message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
     file_objs = request.files.getlist("file")
     for file_obj in file_objs:
         if file_obj.filename == "":
-            return get_result(
-                message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
-            )
+            return get_result(message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
         if len(file_obj.filename.encode("utf-8")) >= 128:
-            return get_result(
-                message="File name should be less than 128 bytes.", code=settings.RetCode.ARGUMENT_ERROR
-            )
-    '''
+            return get_result(message="File name should be less than 128 bytes.", code=settings.RetCode.ARGUMENT_ERROR)
+    """
     # total size
     total_size = 0
     for file_obj in file_objs:
@@ -154,7 +144,7 @@ def upload(dataset_id, tenant_id):
             message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
             code=settings.RetCode.ARGUMENT_ERROR,
         )
-    '''
+    """
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not e:
        raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
@@ -236,8 +226,7 @@ def update_doc(tenant_id, dataset_id, document_id):
         return get_error_data_result(message="You don't own the dataset.")
     e, kb = KnowledgebaseService.get_by_id(dataset_id)
     if not e:
-        return get_error_data_result(
-            message="Can't find this knowledgebase!")
+        return get_error_data_result(message="Can't find this knowledgebase!")
     doc = DocumentService.query(kb_id=dataset_id, id=document_id)
     if not doc:
         return get_error_data_result(message="The dataset doesn't own the document.")
@@ -263,19 +252,14 @@ def update_doc(tenant_id, dataset_id, document_id):
                 message="The name should be less than 128 bytes.",
                 code=settings.RetCode.ARGUMENT_ERROR,
             )
-        if (
-            pathlib.Path(req["name"].lower()).suffix
-            != pathlib.Path(doc.name.lower()).suffix
-        ):
+        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
             return get_result(
                 message="The extension of file can't be changed",
                 code=settings.RetCode.ARGUMENT_ERROR,
             )
         for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
             if d.name == req["name"]:
-                return get_error_data_result(
-                    message="Duplicated document name in the same dataset."
-                )
+                return get_error_data_result(message="Duplicated document name in the same dataset.")
         if not DocumentService.update_by_id(document_id, {"name": req["name"]}):
             return get_error_data_result(message="Database error (Document rename)!")
@@ -287,25 +271,9 @@ def update_doc(tenant_id, dataset_id, document_id):
     if "parser_config" in req:
         DocumentService.update_parser_config(doc.id, req["parser_config"])
     if "chunk_method" in req:
-        valid_chunk_method = {
-            "naive",
-            "manual",
-            "qa",
-            "table",
-            "paper",
-            "book",
-            "laws",
-            "presentation",
-            "picture",
-            "one",
-            "knowledge_graph",
-            "email",
-            "tag"
-        }
+        valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email", "tag"}
         if req.get("chunk_method") not in valid_chunk_method:
-            return get_error_data_result(
-                f"`chunk_method` {req['chunk_method']} doesn't exist"
-            )
+            return get_error_data_result(f"`chunk_method` {req['chunk_method']} doesn't exist")
         if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
             return get_error_data_result(message="Not supported yet!")
@@ -323,9 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
         if not e:
             return get_error_data_result(message="Document not found!")
         if not req.get("parser_config"):
-            req["parser_config"] = get_parser_config(
-                req["chunk_method"], req.get("parser_config")
-            )
+            req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config"))
             DocumentService.update_parser_config(doc.id, req["parser_config"])
         if doc.token_num > 0:
             e = DocumentService.increment_chunk_num(
@@ -343,19 +309,45 @@ def update_doc(tenant_id, dataset_id, document_id):
         status = int(req["enabled"])
         if doc.status != req["enabled"]:
             try:
-                if not DocumentService.update_by_id(
-                        doc.id, {"status": str(status)}):
-                    return get_error_data_result(
-                        message="Database error (Document update)!")
-                settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status},
-                                             search.index_name(kb.tenant_id), doc.kb_id)
+                if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
+                    return get_error_data_result(message="Database error (Document update)!")
+                settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
                 return get_result(data=True)
             except Exception as e:
                 return server_error_response(e)
-    return get_result()
+
+    try:
+        ok, doc = DocumentService.get_by_id(doc.id)
+        if not ok:
+            return get_error_data_result(message="Dataset created failed")
+    except OperationalError as e:
+        logging.exception(e)
+        return get_error_data_result(message="Database operation failed")
+
+    key_mapping = {
+        "chunk_num": "chunk_count",
+        "kb_id": "dataset_id",
+        "token_num": "token_count",
+        "parser_id": "chunk_method",
+    }
+    run_mapping = {
+        "0": "UNSTART",
+        "1": "RUNNING",
+        "2": "CANCEL",
+        "3": "DONE",
+        "4": "FAIL",
+    }
+
+    renamed_doc = {}
+    for key, value in doc.to_dict().items():
+        if key == "run":
+            renamed_doc["run"] = run_mapping.get(str(value))
+        new_key = key_mapping.get(key, key)
+        renamed_doc[new_key] = value
+        if key == "run":
+            renamed_doc["run"] = run_mapping.get(value)
+    return get_result(data=renamed_doc)
 
 
 @manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"])  # noqa: F821
@@ -397,25 +389,17 @@ def download(tenant_id, dataset_id, document_id):
             type: object
     """
     if not document_id:
-        return get_error_data_result(
-            message="Specify document_id please."
-        )
+        return get_error_data_result(message="Specify document_id please.")
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
         return get_error_data_result(message=f"You do not own the dataset {dataset_id}.")
     doc = DocumentService.query(kb_id=dataset_id, id=document_id)
     if not doc:
-        return get_error_data_result(
-            message=f"The dataset not own the document {document_id}."
-        )
+        return get_error_data_result(message=f"The dataset not own the document {document_id}.")
     # The process of downloading
-    doc_id, doc_location = File2DocumentService.get_storage_address(
-        doc_id=document_id
-    )  # minio address
+    doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id)  # minio address
     file_stream = STORAGE_IMPL.get(doc_id, doc_location)
     if not file_stream:
-        return construct_json_result(
-            message="This file is empty.", code=settings.RetCode.DATA_ERROR
-        )
+        return construct_json_result(message="This file is empty.", code=settings.RetCode.DATA_ERROR)
     file = BytesIO(file_stream)
     # Use send_file with a proper filename and MIME type
     return send_file(
@@ -530,9 +514,7 @@ def list_docs(dataset_id, tenant_id):
         desc = False
     else:
         desc = True
-    docs, tol = DocumentService.get_list(
-        dataset_id, page, page_size, orderby, desc, keywords, id, name
-    )
+    docs, tol = DocumentService.get_list(dataset_id, page, page_size, orderby, desc, keywords, id, name)
 
     # rename key's name
     renamed_doc_list = []
@@ -638,9 +620,7 @@ def delete(tenant_id, dataset_id):
         b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
         if not DocumentService.remove_document(doc, tenant_id):
-            return get_error_data_result(
-                message="Database error (Document removal)!"
-            )
+            return get_error_data_result(message="Database error (Document removal)!")
         f2d = File2DocumentService.get_by_document_id(doc_id)
         FileService.filter_delete(
@@ -664,7 +644,10 @@ def delete(tenant_id, dataset_id):
     if duplicate_messages:
         if success_count > 0:
-            return get_result(message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages},)
+            return get_result(
+                message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors",
+                data={"success_count": success_count, "errors": duplicate_messages},
+            )
         else:
             return get_error_data_result(message=";".join(duplicate_messages))
@@ -729,9 +712,7 @@ def parse(tenant_id, dataset_id):
         if not doc:
             return get_error_data_result(message=f"You don't own the document {id}.")
         if 0.0 < doc[0].progress < 1.0:
-            return get_error_data_result(
-                "Can't parse document that is currently being processed"
-            )
+            return get_error_data_result("Can't parse document that is currently being processed")
         info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
         DocumentService.update_by_id(id, info)
         settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
@@ -746,7 +727,10 @@ def parse(tenant_id, dataset_id):
         return get_result(message=f"Documents not found: {not_found}", code=settings.RetCode.DATA_ERROR)
     if duplicate_messages:
         if success_count > 0:
-            return get_result(message=f"Partially parsed {success_count} documents with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages},)
+            return get_result(
+                message=f"Partially parsed {success_count} documents with {len(duplicate_messages)} errors",
+                data={"success_count": success_count, "errors": duplicate_messages},
+            )
         else:
             return get_error_data_result(message=";".join(duplicate_messages))
@@ -808,16 +792,17 @@ def stop_parsing(tenant_id, dataset_id):
         if not doc:
             return get_error_data_result(message=f"You don't own the document {id}.")
         if int(doc[0].progress) == 1 or doc[0].progress == 0:
-            return get_error_data_result(
-                "Can't stop parsing document with progress at 0 or 1"
-            )
+            return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
         info = {"run": "2", "progress": 0, "chunk_num": 0}
         DocumentService.update_by_id(id, info)
         settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
         success_count += 1
     if duplicate_messages:
         if success_count > 0:
-            return get_result(message=f"Partially stopped {success_count} documents with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages},)
+            return get_result(
+                message=f"Partially stopped {success_count} documents with {len(duplicate_messages)} errors",
+                data={"success_count": success_count, "errors": duplicate_messages},
+            )
         else:
             return get_error_data_result(message=";".join(duplicate_messages))
     return get_result()
@@ -906,9 +891,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
     doc = DocumentService.query(id=document_id, kb_id=dataset_id)
     if not doc:
-        return get_error_data_result(
-            message=f"You don't own the document {document_id}."
-        )
+        return get_error_data_result(message=f"You don't own the document {document_id}.")
     doc = doc[0]
     req = request.args
     doc_id = document_id
@@ -956,34 +939,29 @@ def list_chunks(tenant_id, dataset_id, document_id):
             del chunk[n]
         if not chunk:
             return get_error_data_result(f"Chunk `{req.get('id')}` not found.")
-        res['total'] = 1
+        res["total"] = 1
         final_chunk = {
-            "id":chunk.get("id",chunk.get("chunk_id")),
-            "content":chunk["content_with_weight"],
-            "document_id":chunk.get("doc_id",chunk.get("document_id")),
-            "docnm_kwd":chunk["docnm_kwd"],
-            "important_keywords":chunk.get("important_kwd",[]),
-            "questions":chunk.get("question_kwd",[]),
-            "dataset_id":chunk.get("kb_id",chunk.get("dataset_id")),
-            "image_id":chunk.get("img_id", ""),
-            "available":bool(chunk.get("available_int",1)),
-            "positions":chunk.get("position_int",[]),
+            "id": chunk.get("id", chunk.get("chunk_id")),
+            "content": chunk["content_with_weight"],
+            "document_id": chunk.get("doc_id", chunk.get("document_id")),
+            "docnm_kwd": chunk["docnm_kwd"],
+            "important_keywords": chunk.get("important_kwd", []),
+            "questions": chunk.get("question_kwd", []),
+            "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
+            "image_id": chunk.get("img_id", ""),
+            "available": bool(chunk.get("available_int", 1)),
+            "positions": chunk.get("position_int", []),
         }
         res["chunks"].append(final_chunk)
         _ = Chunk(**final_chunk)
     elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
-        sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
-                                           highlight=True)
+        sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
         res["total"] = sres.total
         for id in sres.ids:
             d = {
                 "id": id,
-                "content": (
-                    rmSpace(sres.highlight[id])
-                    if question and id in sres.highlight
-                    else sres.field[id].get("content_with_weight", "")
-                ),
+                "content": (rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get("content_with_weight", "")),
                 "document_id": sres.field[id]["doc_id"],
                 "docnm_kwd": sres.field[id]["docnm_kwd"],
                 "important_keywords": sres.field[id].get("important_kwd", []),
@@ -991,10 +969,10 @@ def list_chunks(tenant_id, dataset_id, document_id):
                 "dataset_id": sres.field[id].get("kb_id", sres.field[id].get("dataset_id")),
                 "image_id": sres.field[id].get("img_id", ""),
                 "available": bool(int(sres.field[id].get("available_int", "1"))),
-                "positions": sres.field[id].get("position_int",[]),
+                "positions": sres.field[id].get("position_int", []),
             }
             res["chunks"].append(d)
-            _ = Chunk(**d) # validate the chunk
+            _ = Chunk(**d)  # validate the chunk
 
     return get_result(data=res)
@@ -1070,23 +1048,17 @@ def add_chunk(tenant_id, dataset_id, document_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
     doc = DocumentService.query(id=document_id, kb_id=dataset_id)
     if not doc:
-        return get_error_data_result(
-            message=f"You don't own the document {document_id}."
-        )
+        return get_error_data_result(message=f"You don't own the document {document_id}.")
     doc = doc[0]
     req = request.json
     if not str(req.get("content", "")).strip():
         return get_error_data_result(message="`content` is required")
     if "important_keywords" in req:
         if not isinstance(req["important_keywords"], list):
-            return get_error_data_result(
-                "`important_keywords` is required to be a list"
-            )
+            return get_error_data_result("`important_keywords` is required to be a list")
     if "questions" in req:
         if not isinstance(req["questions"], list):
-            return get_error_data_result(
-                "`questions` is required to be a list"
-            )
+            return get_error_data_result("`questions` is required to be a list")
     chunk_id = xxhash.xxh64((req["content"] + document_id).encode("utf-8")).hexdigest()
     d = {
         "id": chunk_id,
@@ -1095,22 +1067,16 @@ def add_chunk(tenant_id, dataset_id, document_id):
     }
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     d["important_kwd"] = req.get("important_keywords", [])
-    d["important_tks"] = rag_tokenizer.tokenize(
-        " ".join(req.get("important_keywords", []))
-    )
+    d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", [])))
     d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()]
-    d["question_tks"] = rag_tokenizer.tokenize(
-        "\n".join(req.get("questions", []))
-    )
+    d["question_tks"] = rag_tokenizer.tokenize("\n".join(req.get("questions", [])))
     d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
     d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
     d["kb_id"] = dataset_id
     d["docnm_kwd"] = doc.name
     d["doc_id"] = document_id
     embd_id = DocumentService.get_embd_id(document_id)
-    embd_mdl = TenantLLMService.model_instance(
-        tenant_id, LLMType.EMBEDDING.value, embd_id
-    )
+    embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value, embd_id)
     v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
     v = 0.1 * v[0] + 0.9 * v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
@@ -1203,7 +1169,10 @@ def rm_chunk(tenant_id, dataset_id, document_id):
             return get_result(message=f"deleted {chunk_number} chunks")
         return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(unique_chunk_ids)}")
     if duplicate_messages:
-        return get_result(message=f"Partially deleted {chunk_number} chunks with {len(duplicate_messages)} errors", data={"success_count": chunk_number, "errors": duplicate_messages},)
+        return get_result(
+            message=f"Partially deleted {chunk_number} chunks with {len(duplicate_messages)} errors",
+            data={"success_count": chunk_number, "errors": duplicate_messages},
+        )
     return get_result(message=f"deleted {chunk_number} chunks")
@@ -1271,9 +1240,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
     doc = DocumentService.query(id=document_id, kb_id=dataset_id)
     if not doc:
-        return get_error_data_result(
-            message=f"You don't own the document {document_id}."
-        )
+        return get_error_data_result(message=f"You don't own the document {document_id}.")
     doc = doc[0]
     req = request.json
     if "content" in req:
@@ -1296,19 +1263,13 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
     if "available" in req:
         d["available_int"] = int(req["available"])
     embd_id = DocumentService.get_embd_id(document_id)
-    embd_mdl = TenantLLMService.model_instance(
-        tenant_id, LLMType.EMBEDDING.value, embd_id
-    )
+    embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value, embd_id)
     if doc.parser_id == ParserType.QA:
         arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1]
         if len(arr) != 2:
-            return get_error_data_result(
-                message="Q&A must be separated by TAB/ENTER key."
-            )
+            return get_error_data_result(message="Q&A must be separated by TAB/ENTER key.")
         q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
-        d = beAdoc(
-            d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a])
-        )
+        d = beAdoc(d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a]))
     v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
     v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
@@ -1425,9 +1386,7 @@ def retrieval_test(tenant_id):
         doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
         for doc_id in doc_ids:
             if doc_id not in doc_ids_list:
-                return get_error_data_result(
-                    f"The datasets don't own the document {doc_id}"
-                )
+                return get_error_data_result(f"The datasets don't own the document {doc_id}")
     similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top_k", 1024))
@@ -1463,14 +1422,10 @@ def retrieval_test(tenant_id):
             doc_ids,
             rerank_mdl=rerank_mdl,
             highlight=highlight,
-            rank_feature=label_question(question, kbs)
+            rank_feature=label_question(question, kbs),
         )
         if use_kg:
-            ck = settings.kg_retrievaler.retrieval(question,
-                                                   [k.tenant_id for k in kbs],
-                                                   kb_ids,
-                                                   embd_mdl,
-                                                   LLMBundle(kb.tenant_id, LLMType.CHAT))
+            ck = settings.kg_retrievaler.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ranks["chunks"].insert(0, ck)
@@ -1487,7 +1442,7 @@ def retrieval_test(tenant_id):
             "important_kwd": "important_keywords",
             "question_kwd": "questions",
             "docnm_kwd": "document_keyword",
-            "kb_id":"dataset_id"
+            "kb_id": "dataset_id",
         }
         rename_chunk = {}
         for key, value in chunk.items():

Changed file 2 of 2:

@@ -15,6 +15,7 @@
 #
 import json
+
 from .base import Base
 from .chunk import Chunk
@@ -52,12 +53,14 @@ class Document(Base):
         if "meta_fields" in update_message:
             if not isinstance(update_message["meta_fields"], dict):
                 raise Exception("meta_fields must be a dictionary")
-        res = self.put(f'/datasets/{self.dataset_id}/documents/{self.id}',
-                       update_message)
+        res = self.put(f"/datasets/{self.dataset_id}/documents/{self.id}", update_message)
         res = res.json()
         if res.get("code") != 0:
             raise Exception(res["message"])
+        self._update_from_dict(self.rag, res.get("data", {}))
+        return self
 
     def download(self):
         res = self.get(f"/datasets/{self.dataset_id}/documents/{self.id}")
         try:
@@ -66,9 +69,9 @@ class Document(Base):
         except json.JSONDecodeError:
             return res.content
 
-    def list_chunks(self, page=1, page_size=30, keywords="", id = ""):
+    def list_chunks(self, page=1, page_size=30, keywords="", id=""):
         data = {"keywords": keywords, "page": page, "page_size": page_size, "id": id}
-        res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
+        res = self.get(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", data)
         res = res.json()
         if res.get("code") == 0:
             chunks = []
@@ -79,8 +82,7 @@ class Document(Base):
             raise Exception(res.get("message"))
 
     def add_chunk(self, content: str, important_keywords: list[str] = [], questions: list[str] = []):
-        res = self.post(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks',
-                        {"content": content, "important_keywords": important_keywords, "questions": questions})
+        res = self.post(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", {"content": content, "important_keywords": important_keywords, "questions": questions})
         res = res.json()
         if res.get("code") == 0:
             return Chunk(self.rag, res["data"].get("chunk"))