diff --git a/README.md b/README.md
index 686f6bb42..28b9692b9 100644
--- a/README.md
+++ b/README.md
@@ -172,6 +172,7 @@ $ docker compose up -d
 
 ## 🆕 Latest Features
 
+- 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
 - 2024-04-10 Add a new layout recognize model for method 'Laws'.
 - 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.
 - 2024-04-07 Support Chinese UI.
diff --git a/README_ja.md b/README_ja.md
index d3074bd14..3197279e7 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -171,6 +171,8 @@ $ docker compose up -d
 ```
 
 ## 🆕 最新の新機能
+
+- 2024-04-11 ローカル LLM デプロイメント用に [Xinference](./docs/xinference.md) をサポートします。
 - 2024-04-10 メソッド「Laws」に新しいレイアウト認識モデルを追加します。
 - 2024-04-08 [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
 - 2024-04-07 中国語インターフェースをサポートします。
diff --git a/README_zh.md b/README_zh.md
index 7c7d571b4..143fcc769 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -172,6 +172,7 @@ $ docker compose up -d
 
 ## 🆕 最近新特性
 
+- 2024-04-11 支持用 [Xinference](./docs/xinference.md) 对大模型进行本地化部署。
 - 2024-04-10 为‘Laws’版面分析增加了模型。
 - 2024-04-08 支持用 [Ollama](./docs/ollama.md) 对大模型进行本地化部署。
 - 2024-04-07 支持中文界面。
diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index 02f629ac5..5ee940f46 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -22,6 +22,7 @@ from werkzeug.wrappers.request import Request
 from flask_cors import CORS
 
 from api.db import StatusEnum
+from api.db.db_models import close_connection
 from api.db.services import UserService
 from api.utils import CustomJSONEncoder
 
@@ -42,7 +43,7 @@ for h in access_logger.handlers:
 Request.json = property(lambda self: self.get_json(force=True, silent=True))
 
 app = Flask(__name__)
-CORS(app, supports_credentials=True,max_age = 2592000)
+CORS(app, supports_credentials=True, max_age=2592000)
 app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)
@@ -94,8 +95,6 @@ client_urls_prefix = [
 ]
 
 
-
-
 @login_manager.request_loader
 def load_user(web_request):
     jwt = Serializer(secret_key=SECRET_KEY)
@@ -112,4 +111,9 @@ def load_user(web_request):
             stat_logger.exception(e)
             return None
     else:
-        return None
\ No newline at end of file
+        return None
+
+
+@app.teardown_request
+def _db_close(exc):
+    close_connection()
\ No newline at end of file
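The `@app.teardown_request` hook added to `api/apps/__init__.py` returns each request's database connection when the request finishes, so long-running Flask workers stop leaking pooled connections. A minimal standalone sketch of the same pattern, assuming a hypothetical Peewee pooled database named `DB` rather than the project's actual `api.db.db_models` internals:

```python
# Sketch of per-request connection recycling, assuming a Peewee pooled
# database named DB; the project centralizes the equivalent cleanup in
# api.db.db_models.close_connection().
from flask import Flask
from playhouse.pool import PooledMySQLDatabase

app = Flask(__name__)
DB = PooledMySQLDatabase("ragflow", max_connections=32,
                         user="root", password="secret", host="localhost")


@app.teardown_request
def _db_close(exc):
    # Runs after every request, even if the view raised; returning the
    # connection here keeps the pool from filling up with stale handles.
    if not DB.is_closed():
        DB.close()
```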
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index fd9c266ea..8c42c804b 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -360,6 +360,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
              "|" for r in tbl["rows"]]
     rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
+
     if not docid_idx or not docnm_idx:
         chat_logger.warning("SQL missing field: " + sql)
         return {
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 4cc72a2d5..2e5026af4 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -109,6 +109,12 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+},
+    {
+        "name": "Xinference",
+        "logo": "",
+        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+        "status": "1",
 },
 # {
 #     "name": "文心一言",
diff --git a/docker/docker-compose-CN.yml b/docker/docker-compose-CN.yml
index 262163420..a4f3f77c3 100644
--- a/docker/docker-compose-CN.yml
+++ b/docker/docker-compose-CN.yml
@@ -20,7 +20,6 @@ services:
       - 443:443
     volumes:
       - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
-      - ./entrypoint.sh:/ragflow/entrypoint.sh
       - ./ragflow-logs:/ragflow/logs
       - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
       - ./nginx/proxy.conf:/etc/nginx/proxy.conf
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index f5ad8f8b2..312b5329e 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -19,7 +19,6 @@ services:
       - 443:443
     volumes:
       - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
-      - ./entrypoint.sh:/ragflow/entrypoint.sh
       - ./ragflow-logs:/ragflow/logs
       - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
       - ./nginx/proxy.conf:/etc/nginx/proxy.conf
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index c3fc7db81..c088a7f94 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -21,6 +21,7 @@ from .cv_model import *
 EmbeddingModel = {
     "Ollama": OllamaEmbed,
     "OpenAI": OpenAIEmbed,
+    "Xinference": XinferenceEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
     "Moonshot": HuEmbedding
@@ -30,6 +31,7 @@ EmbeddingModel = {
 CvModel = {
     "OpenAI": GptV4,
     "Ollama": OllamaCV,
+    "Xinference": XinferenceCV,
     "Tongyi-Qianwen": QWenCV,
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV
@@ -41,6 +43,7 @@ ChatModel = {
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
     "Ollama": OllamaChat,
+    "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat
 }
 
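These `EmbeddingModel`, `CvModel`, and `ChatModel` dictionaries are the factory registries the rest of the codebase resolves provider names against, so adding the `"Xinference"` keys is what makes the new wrapper classes reachable at runtime. A hedged sketch of such a lookup; the model name and endpoint below are illustrative assumptions (Xinference exposes an OpenAI-compatible API, by default on port 9997 under `/v1`):

```python
# Illustrative factory resolution; the model name and URL are assumptions
# for a local Xinference deployment, not values taken from this diff.
from rag.llm import ChatModel

mdl = ChatModel["Xinference"](
    key=None,                             # Xinference needs no real API key
    model_name="qwen-chat",               # a model launched inside Xinference
    base_url="http://localhost:9997/v1",  # assumed local endpoint
)
answer, completion_tokens = mdl.chat(
    "You are a helpful assistant.",            # system prompt
    [{"role": "user", "content": "Hello!"}],   # message history
    {"temperature": 0.7},                      # gen_conf, forwarded to the API
)
```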
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index d4f0e7b64..b9bb36d73 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -158,6 +158,28 @@ class OllamaChat(Base):
             return "**ERROR**: " + str(e), 0
 
 
+class XinferenceChat(Base):
+    def __init__(self, key=None, model_name="", base_url=""):
+        self.client = OpenAI(api_key="xxx", base_url=base_url)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.completion_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index d764bc873..4b966991b 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -161,6 +161,22 @@ class OllamaCV(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+
+class XinferenceCV(Base):
+    def __init__(self, key, model_name="", lang="Chinese", base_url=""):
+        self.client = OpenAI(api_key=key, base_url=base_url)
+        self.model_name = model_name
+        self.lang = lang
+
+    def describe(self, image, max_tokens=300):
+        b64 = self.image2base64(image)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=self.prompt(b64),
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
 
 class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index d5b763d18..aa6b565b8 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -170,3 +170,20 @@ class OllamaEmbed(Base):
         res = self.client.embeddings(prompt=text,
                                      model=self.model_name)
         return np.array(res["embedding"]), 128
+
+
+class XinferenceEmbed(Base):
+    def __init__(self, key, model_name="", base_url=""):
+        self.client = OpenAI(api_key="xxx", base_url=base_url)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts,
+                                            model=self.model_name)
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings.create(input=[text],
+                                            model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
diff --git a/rag/settings.py b/rag/settings.py
index f84831df8..da022628f 100644
--- a/rag/settings.py
+++ b/rag/settings.py
@@ -34,7 +34,7 @@ LoggerFactory.set_directory(
         "logs",
         "rag"))
 # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
-LoggerFactory.LEVEL = 10
+LoggerFactory.LEVEL = 30
 
 es_logger = getLogger("es")
 minio_logger = getLogger("minio")
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 1f5e37af2..6ea80d9c9 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -24,6 +24,8 @@ import sys
 import time
 import traceback
 from functools import partial
+
+from api.db.db_models import close_connection
 from rag.settings import database_logger
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 from multiprocessing import Pool
@@ -302,3 +304,4 @@ if __name__ == "__main__":
     comm = MPI.COMM_WORLD
     while True:
         main(int(sys.argv[2]), int(sys.argv[1]))
+        close_connection()
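The new embedding wrapper follows the same OpenAI-client convention as the chat class, so usage mirrors it; the model name and endpoint here are again assumptions for illustration:

```python
# Hedged usage sketch for XinferenceEmbed; adjust model_name and base_url
# to whatever is actually launched on your Xinference server.
import numpy as np

from rag.llm import EmbeddingModel

embd = EmbeddingModel["Xinference"](
    key="xxx",                            # ignored: the wrapper hardcodes a dummy key
    model_name="bge-large-zh",            # an embedding model served by Xinference
    base_url="http://localhost:9997/v1",
)
vectors, token_count = embd.encode(["first chunk", "second chunk"])
assert isinstance(vectors, np.ndarray) and vectors.shape[0] == 2

query_vec, _ = embd.encode_queries("a user question")
```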