Add automatic keyword extraction and automatic question proposal (#2965)

### What problem does this PR solve?

#2687

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2024-10-22 13:12:49 +08:00
committed by GitHub
parent 5aa9d7787e
commit 226bdd6e99
8 changed files with 119 additions and 61 deletions

View File

@ -25,7 +25,7 @@ from api.db import FileType, LLMType, ParserType, FileSource
from api.db.db_models import APIToken, Task, File from api.db.db_models import APIToken, Task, File
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.api_service import APITokenService, API4ConversationService from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.dialog_service import DialogService, chat from api.db.services.dialog_service import DialogService, chat, keyword_extraction
from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
@ -38,7 +38,6 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
generate_confirmation_token generate_confirmation_token
from api.utils.file_utils import filename_type, thumbnail from api.utils.file_utils import filename_type, thumbnail
from rag.nlp import keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import UserCanvasService

View File

@ -21,8 +21,9 @@ from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from api.db.services.dialog_service import keyword_extraction
from rag.app.qa import rmPrefix, beAdoc from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer, keyword_extraction from rag.nlp import search, rag_tokenizer
from rag.utils.es_conn import ELASTICSEARCH from rag.utils.es_conn import ELASTICSEARCH
from rag.utils import rmSpace from rag.utils import rmSpace
from api.db import LLMType, ParserType from api.db import LLMType, ParserType

View File

@ -16,16 +16,15 @@
from flask import request from flask import request
from api.db import StatusEnum from api.db import StatusEnum
from api.db.db_models import TenantLLM
from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result from api.utils.api_utils import get_result
@manager.route('/chat', methods=['POST']) @manager.route('/chat', methods=['POST'])
@token_required @token_required
def create(tenant_id): def create(tenant_id):

View File

@ -1,10 +1,25 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request, jsonify from flask import request, jsonify
from db import LLMType, ParserType from api.db import LLMType, ParserType
from db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from settings import retrievaler, kg_retrievaler, RetCode from api.settings import retrievaler, kg_retrievaler, RetCode
from utils.api_utils import validate_request, build_error_result, apikey_required from api.utils.api_utils import validate_request, build_error_result, apikey_required
@manager.route('/dify/retrieval', methods=['POST']) @manager.route('/dify/retrieval', methods=['POST'])

View File

@ -1,48 +1,37 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib import pathlib
import re
import datetime import datetime
import json
import traceback
from botocore.docs.method import document_model_driven_method
from flask import request
from flask_login import login_required, current_user
from elasticsearch_dsl import Q
from pygments import highlight
from sphinx.addnodes import document
from api.db.services.dialog_service import keyword_extraction
from rag.app.qa import rmPrefix, beAdoc from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer, keyword_extraction from rag.nlp import rag_tokenizer
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils import rmSpace
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService from api.settings import kg_retrievaler
from api.utils.api_utils import server_error_response, get_error_data_result, validate_request
from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler, kg_retrievaler
from api.utils.api_utils import get_result
import hashlib import hashlib
import re import re
from api.utils.api_utils import get_result, token_required, get_error_data_result from api.utils.api_utils import token_required
from api.db.db_models import Task
from api.db.db_models import Task, File
from api.db.services.task_service import TaskService, queue_tasks from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import server_error_response
from api.utils.api_utils import get_result, get_error_data_result
from api.utils.api_utils import server_error_response, get_error_data_result, validate_request
from api.utils.api_utils import get_result, get_result, get_error_data_result
from functools import partial
from io import BytesIO from io import BytesIO
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from flask import request, send_file from flask import request, send_file
from flask_login import login_required
from api.db import FileSource, TaskStatus, FileType from api.db import FileSource, TaskStatus, FileType
from api.db.db_models import File from api.db.db_models import File
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
@ -50,8 +39,7 @@ from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import RetCode, retrievaler from api.settings import RetCode, retrievaler
from api.utils.api_utils import construct_json_result, construct_error_response from api.utils.api_utils import construct_json_result
from rag.app import book, laws, manual, naive, one, paper, presentation, qa, resume, table, picture, audio, email
from rag.nlp import search from rag.nlp import search
from rag.utils import rmSpace from rag.utils import rmSpace
from rag.utils.es_conn import ELASTICSEARCH from rag.utils.es_conn import ELASTICSEARCH
@ -365,7 +353,6 @@ def list_chunks(tenant_id,dataset_id,document_id):
return get_result(data=res) return get_result(data=res)
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST']) @manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST'])
@token_required @token_required
def create(tenant_id,dataset_id,document_id): def create(tenant_id,dataset_id,document_id):
@ -454,7 +441,6 @@ def rm_chunk(tenant_id,dataset_id,document_id):
return get_result() return get_result()
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT']) @manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT'])
@token_required @token_required
def update_chunk(tenant_id,dataset_id,document_id,chunk_id): def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
@ -512,7 +498,6 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
return get_result() return get_result()
@manager.route('/retrieval', methods=['POST']) @manager.route('/retrieval', methods=['POST'])
@token_required @token_required
def retrieval_test(tenant_id): def retrieval_test(tenant_id):

View File

@ -28,7 +28,6 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import chat_logger, retrievaler, kg_retrievaler from api.settings import chat_logger, retrievaler, kg_retrievaler
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.nlp import keyword_extraction
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
@ -80,6 +79,7 @@ class ConversationService(CommonService):
return list(sessions.dicts()) return list(sessions.dicts())
def message_fit_in(msg, max_length=4000): def message_fit_in(msg, max_length=4000):
def count(): def count():
nonlocal msg nonlocal msg
@ -456,6 +456,58 @@ def rewrite(tenant_id, llm_id, question):
return ans return ans
def keyword_extraction(chat_mdl, content, topn=3):
    """Extract the most important keywords/phrases from a piece of text.

    Args:
        chat_mdl: chat-model wrapper exposing ``chat(system, history, gen_conf)``
            and a ``max_length`` attribute (e.g. LLMBundle) — assumed, confirm
            against callers.
        content: the text to analyze.
        topn: number of keywords/phrases to request from the model.

    Returns:
        A comma-delimited keyword string, or ``""`` when the model reports an
        error (callers split the result on commas).
    """
    prompt = f"""
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
- Summarize the text content, and give top {topn} important keywords/phrases.
- The keywords MUST be in language of the given piece of text content.
- The keywords are delimited by ENGLISH COMMA.
- Keywords ONLY in output.

### Text Content
{content}
"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    # Trim the conversation so it fits the model's context window.
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    # Some providers return (answer, token_count); keep only the answer text.
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    if "**ERROR**" in kwd:
        return ""
    # Strip stray whitespace/newlines so downstream `.split(",")` yields clean keywords.
    return kwd.strip()
def question_proposal(chat_mdl, content, topn=3):
    """Propose the top-``topn`` questions a piece of text can answer.

    Args:
        chat_mdl: chat-model wrapper exposing ``chat(system, history, gen_conf)``
            and a ``max_length`` attribute (e.g. LLMBundle) — assumed, confirm
            against callers.
        content: the text to analyze.
        topn: number of questions to request from the model.

    Returns:
        Newline-separated questions, or ``""`` when the model reports an error.
    """
    prompt = f"""
Role: You're a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
- Understand and summarize the text content, and propose top {topn} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in language of the given piece of text content.
- One question per line.
- Question ONLY in output.

### Text Content
{content}
"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    # Trim the conversation so it fits the model's context window.
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    ans = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    # Some providers return (answer, token_count); keep only the answer text.
    if isinstance(ans, tuple):
        ans = ans[0]
    if "**ERROR**" in ans:
        return ""
    # Strip stray surrounding whitespace before the result is embedded/tokenized.
    return ans.strip()
def full_question(tenant_id, llm_id, messages): def full_question(tenant_id, llm_id, messages):
if llm_id2llm_type(llm_id) == "image2text": if llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)

View File

@ -570,14 +570,3 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。"):
return cks, images return cks, images
def keyword_extraction(chat_mdl, content):
    """Ask the chat model for the most important keyword/phrase of *content*.

    Returns the model's raw answer text (unwrapped if the provider returns a
    (text, token_count) tuple).
    """
    prompt = """
You're a question analyzer.
1. Please give me the most important keyword/phrase of this question.
Answer format: (in language of user's question)
- keyword:
"""
    answer = chat_mdl.chat(
        prompt,
        [{"role": "user", "content": content}],
        {"temperature": 0.2},
    )
    # Unwrap (text, token_count)-style responses to just the text.
    return answer[0] if isinstance(answer, tuple) else answer

View File

@ -34,6 +34,7 @@ import pandas as pd
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.dialog_service import keyword_extraction, question_proposal
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService from api.db.services.task_service import TaskService
@ -198,6 +199,23 @@ def build(row):
d["_id"] = md5.hexdigest() d["_id"] = md5.hexdigest()
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp() d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
if row["parser_config"].get("auto_keywords", 0):
chat_mdl = LLMBundle(row["tenant_id"], LLMType.CHAT, llm_name=row["llm_id"], lang=row["language"])
d["important_kwd"] = keyword_extraction(chat_mdl, ck["content_with_weight"],
row["parser_config"]["auto_keywords"]).split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
if row["parser_config"].get("auto_questions", 0):
chat_mdl = LLMBundle(row["tenant_id"], LLMType.CHAT, llm_name=row["llm_id"], lang=row["language"])
qst = question_proposal(chat_mdl, ck["content_with_weight"], row["parser_config"]["auto_keywords"])
ck["content_with_weight"] = f"Question: \n{qst}\n\nAnswer:\n" + ck["content_with_weight"]
qst = rag_tokenizer.tokenize(qst)
if "content_ltks" in ck:
ck["content_ltks"] += " " + qst
if "content_sm_ltks" in ck:
ck["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
if not d.get("image"): if not d.get("image"):
docs.append(d) docs.append(d)
continue continue