diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 0a88f3027..9337bd6bf 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -27,6 +27,8 @@ from groq import Groq import os import json import requests +import asyncio +from rag.svr.jina_server import Prompt,Generation class Base(ABC): def __init__(self, key, model_name, base_url): @@ -381,8 +383,10 @@ class LocalLLM(Base): def __conn(self): from multiprocessing.connection import Client + self._connection = Client( - (self.host, self.port), authkey=b'infiniflow-token4kevinhu') + (self.host, self.port), authkey=b"infiniflow-token4kevinhu" + ) def __getattr__(self, name): import pickle @@ -390,8 +394,7 @@ class LocalLLM(Base): def do_rpc(*args, **kwargs): for _ in range(3): try: - self._connection.send( - pickle.dumps((name, args, kwargs))) + self._connection.send(pickle.dumps((name, args, kwargs))) return pickle.loads(self._connection.recv()) except Exception as e: self.__conn() @@ -399,35 +402,45 @@ class LocalLLM(Base): return do_rpc - def __init__(self, key, model_name="glm-3-turbo"): - self.client = LocalLLM.RPCProxy("127.0.0.1", 7860) + def __init__(self, key, model_name): + from jina import Client - def chat(self, system, history, gen_conf): + self.client = Client(port=12345, protocol="grpc", asyncio=True) + + def _prepare_prompt(self, system, history, gen_conf): if system: history.insert(0, {"role": "system", "content": system}) - try: - ans = self.client.chat( - history, - gen_conf - ) - return ans, num_tokens_from_string(ans) - except Exception as e: - return "**ERROR**: " + str(e), 0 + if "max_tokens" in gen_conf: + gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens") + return Prompt(message=history, gen_conf=gen_conf) - def chat_streamly(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - token_count = 0 + def _stream_response(self, endpoint, prompt): answer = "" try: - for ans in self.client.chat_streamly(history, gen_conf): - answer += ans - token_count += 1 - yield answer + res = self.client.stream_doc( + on=endpoint, inputs=prompt, return_type=Generation + ) + loop = asyncio.get_event_loop() + try: + while True: + answer = loop.run_until_complete(res.__anext__()).text + yield answer + except StopAsyncIteration: + pass except Exception as e: yield answer + "\n**ERROR**: " + str(e) + yield num_tokens_from_string(answer) - yield token_count + def chat(self, system, history, gen_conf): + prompt = self._prepare_prompt(system, history, gen_conf) + chat_gen = self._stream_response("/chat", prompt) + ans = next(chat_gen) + total_tokens = next(chat_gen) + return ans, total_tokens + + def chat_streamly(self, system, history, gen_conf): + prompt = self._prepare_prompt(system, history, gen_conf) + return self._stream_response("/stream", prompt) class VolcEngineChat(Base): diff --git a/rag/svr/jina_server.py b/rag/svr/jina_server.py new file mode 100644 index 000000000..9dba8b55e --- /dev/null +++ b/rag/svr/jina_server.py @@ -0,0 +1,93 @@ +from jina import Deployment +from docarray import BaseDoc +from jina import Executor, requests +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +import argparse +import torch + + +class Prompt(BaseDoc): + message: list[dict] + gen_conf: dict + + +class Generation(BaseDoc): + text: str + + +tokenizer = None +model_name = "" + + +class TokenStreamingExecutor(Executor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", torch_dtype="auto" + ) + + @requests(on="/chat") + async def generate(self, doc: Prompt, **kwargs) -> Generation: + text = tokenizer.apply_chat_template( + doc.message, + tokenize=False, + ) + inputs = tokenizer([text], return_tensors="pt") + generation_config = GenerationConfig( + **doc.gen_conf, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.eos_token_id + ) + generated_ids = self.model.generate( + inputs.input_ids, generation_config=generation_config + ) + generated_ids = [ + output_ids[len(input_ids) :] + for input_ids, output_ids in zip(inputs.input_ids, generated_ids) + ] + + response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + yield Generation(text=response) + + @requests(on="/stream") + async def task(self, doc: Prompt, **kwargs) -> Generation: + text = tokenizer.apply_chat_template( + doc.message, + tokenize=False, + ) + input = tokenizer([text], return_tensors="pt") + input_len = input["input_ids"].shape[1] + max_new_tokens = 512 + if "max_new_tokens" in doc.gen_conf: + max_new_tokens = doc.gen_conf.pop("max_new_tokens") + generation_config = GenerationConfig( + **doc.gen_conf, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.eos_token_id + ) + for _ in range(max_new_tokens): + output = self.model.generate( + **input, max_new_tokens=1, generation_config=generation_config + ) + if output[0][-1] == tokenizer.eos_token_id: + break + yield Generation( + text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True) + ) + input = { + "input_ids": output, + "attention_mask": torch.ones(1, len(output[0])), + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, help="Model name or path") + parser.add_argument("--port", default=12345, type=int, help="Jina serving port") + args = parser.parse_args() + model_name = args.model_name + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + with Deployment( + uses=TokenStreamingExecutor, port=args.port, protocol="grpc" + ) as dep: + dep.block()