add local llm implementation (#119)

This commit is contained in:
KevinHuSh
2024-03-12 11:57:08 +08:00
committed by GitHub
parent 0452a6db73
commit f1f09df901
17 changed files with 196 additions and 25 deletions

View File

@@ -19,22 +19,25 @@ from .cv_model import *
# Factory tables: map a model-supplier name to the wrapper class that
# implements it.  The "local" entries route to locally deployed models
# (served by rag/llm/rpc_server.py — see LocalLLM / LocalCV).

# Text-embedding backends.
EmbeddingModel = {
    "Infiniflow": HuEmbedding,
    "local": HuEmbedding,
    "OpenAI": OpenAIEmbed,
    "通义千问": HuEmbedding,  # QWenEmbed,
    "智谱AI": ZhipuEmbed
}

# Vision (image description) backends.
CvModel = {
    "OpenAI": GptV4,
    "Infiniflow": GptV4,
    "local": LocalCV,
    "通义千问": QWenCV,
    "智谱AI": Zhipu4V
}

# Chat-completion backends.
ChatModel = {
    "OpenAI": GptTurbo,
    "Infiniflow": GptTurbo,
    "智谱AI": ZhipuChat,
    "通义千问": QWenChat,
    "local": LocalLLM
}

View File

@@ -20,6 +20,7 @@ from openai import OpenAI
import openai
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
class Base(ABC):
@@ -86,7 +87,6 @@ class ZhipuChat(Base):
self.model_name = model_name
def chat(self, system, history, gen_conf):
from http import HTTPStatus
if system: history.insert(0, {"role": "system", "content": system})
try:
response = self.client.chat.completions.create(
@@ -100,4 +100,42 @@ class ZhipuChat(Base):
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, response.usage.completion_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
return "**ERROR**: " + str(e), 0
class LocalLLM(Base):
    """Chat model backed by a locally running rpc_server.py process,
    reached over a multiprocessing.connection socket on 127.0.0.1:7860."""

    class RPCProxy:
        """Thin RPC client: any public attribute access returns a callable
        that forwards the call, pickled, to the remote server.

        SECURITY: pickle is exchanged over the socket; the authkey gives
        HMAC authentication, but this should only ever talk to a trusted
        local server.
        """

        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            # (Re)open the connection; authkey must match rpc_server.py's.
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            # Only resolve public names as remote calls.  Without this
            # guard, a missing internal attribute (e.g. _connection after a
            # failed connect) would silently become an RPC stub instead of
            # raising AttributeError.
            if name.startswith("_"):
                raise AttributeError(name)
            import pickle

            def do_rpc(*args, **kwargs):
                last_exc = None
                # Retry up to 3 times, reconnecting after each failure.
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception as e:
                        last_exc = e
                        self.__conn()
                # Preserve the underlying failure for diagnosis.
                raise Exception("RPC connection lost!") from last_exc

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        # key / model_name are accepted for interface parity with the other
        # chat wrappers but are unused: the local server chooses the model.
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        """Send the chat history to the local model.

        Returns (answer_text, completion_token_count); on failure returns
        ("**ERROR**: ...", 0) like the other wrappers in this module.
        NOTE: mutates *history* in place when *system* is given (matches
        ZhipuChat's behavior).
        """
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(history, gen_conf)
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

View File

@@ -138,3 +138,11 @@ class Zhipu4V(Base):
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
class LocalCV(Base):
    """No-op vision model for the "local" supplier: image description is
    not implemented locally, so describe() returns an empty caption."""

    def __init__(self, key, model_name="glm-4v", lang="Chinese"):
        # Parameters accepted only for interface parity with the other CV
        # wrappers; there is nothing to initialize.
        pass

    def describe(self, image, max_tokens=1024):
        # Empty description, zero tokens consumed.
        return "", 0

90
rag/llm/rpc_server.py Normal file
View File

@@ -0,0 +1,90 @@
import argparse
import pickle
import random
import time
from multiprocessing.connection import Listener
from threading import Thread
import torch
class RPCHandler:
    """Registry of callables exposed over a multiprocessing connection.

    Wire protocol: each request is a pickled (func_name, args, kwargs)
    tuple; the reply is the pickled return value, or the pickled exception
    object when the call fails.
    """

    def __init__(self):
        # Maps function name -> callable available to remote clients.
        self._functions = {}

    def register_function(self, func):
        """Expose *func* to remote callers under its own __name__."""
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        """Serve requests from one client until it disconnects."""
        try:
            while True:
                # SECURITY: pickle.loads on socket data — acceptable only
                # because the Listener authenticates peers via its authkey.
                func_name, args, kwargs = pickle.loads(connection.recv())
                try:
                    result = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(result))
                except Exception as e:
                    # Ship the failure back to the caller instead of
                    # killing this worker thread.
                    connection.send(pickle.dumps(e))
        except EOFError:
            # Client closed the connection: normal end of session.
            pass
def rpc_server(hdlr, address, authkey):
    """Listen on *address* forever, serving each client on its own daemon
    thread via hdlr.handle_connection.

    Never returns; accept/spawn errors are printed and the loop continues.
    """
    listener = Listener(address, authkey=authkey)
    while True:
        try:
            client = listener.accept()
            worker = Thread(target=hdlr.handle_connection, args=(client,),
                            daemon=True)
            worker.start()
        except Exception as e:
            print("【EXCEPTION】:", str(e))
# Pool of loaded model replicas; populated in the __main__ block below.
models = []
# Tokenizer shared by all replicas; assigned in the __main__ block below.
tokenizer = None


def chat(messages, gen_conf):
    """Generate a completion for *messages* (OpenAI-style dicts with
    "role" and "content" keys) using one of the pooled local models.

    *gen_conf* is forwarded verbatim to model.generate().  Returns the
    generated text after the final "Assistant: " marker.
    """
    # NOTE: no `global` statement needed — the original's `global tokenizer`
    # was redundant since the global is only read, never assigned here.
    model = Model()
    # Flatten the chat history into one plain-text transcript prompt.
    roles = {"system": "System", "user": "User", "assistant": "Assistant"}
    transcript = "\n".join(
        "{}: {}".format(roles[m["role"].lower()], m["content"])
        for m in messages
    ) + "\nAssistant: "
    tokens = tokenizer([transcript], return_tensors='pt')
    # Move tensors onto the chosen model's device; pass through anything
    # that is not a tensor unchanged.
    tokens = {k: (v.to(model.device) if isinstance(v, torch.Tensor) else v)
              for k, v in tokens.items()}
    res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
    # NOTE(review): assumes "Assistant: " never appears inside the
    # generated content itself — confirm for the target models.
    return res.split("Assistant: ")[-1]


def Model():
    """Pick one loaded replica uniformly at random (naive load balancing).

    Fix: the original reseeded the global RNG with time.time() on every
    call, which is unnecessary for random.choice and can bias selection
    for rapid successive calls.
    """
    return random.choice(models)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
    args = parser.parse_args()

    # Expose chat() to RPC clients.
    handler = RPCHandler()
    handler.register_function(chat)

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation.utils import GenerationConfig

    # Load two replicas of the model — presumably so concurrent requests
    # can be spread across them via Model()'s random pick (verify intent).
    models = []
    for _ in range(2):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype='auto',
                                                 trust_remote_code=True)
        m.generation_config = GenerationConfig.from_pretrained(args.model_name)
        # Use EOS as the pad token when the config lacks an explicit one.
        m.generation_config.pad_token_id = m.generation_config.eos_token_id
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
                                              trust_remote_code=True)

    # Run the server (blocks forever); authkey must match LocalLLM.RPCProxy.
    rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')