Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
apply pep8 formalize (#155)
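The commit message only says that PEP 8 formatting was applied and does not name a tool; the aggressive call-wrapping visible in the diff below is characteristic of an automatic fixer such as autopep8. A hedged sketch of reproducing that kind of reformatting programmatically (autopep8 and the file name are assumptions, not stated in the commit):

# Sketch only: autopep8 is assumed, the commit does not name the formatter,
# and "rpc_server.py" is an illustrative file name.
import autopep8

with open("rpc_server.py") as f:
    source = f.read()

# aggressive level 2 also rewraps long lines, matching the style in this diff
fixed = autopep8.fix_code(source, options={"aggressive": 2})

with open("rpc_server.py", "w") as f:
    f.write(fixed)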
@@ -9,7 +9,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 class RPCHandler:
     def __init__(self):
-        self._functions = { }
+        self._functions = {}
 
     def register_function(self, func):
         self._functions[func.__name__] = func
@@ -21,12 +21,12 @@ class RPCHandler:
                 func_name, args, kwargs = pickle.loads(connection.recv())
                 # Run the RPC and send a response
                 try:
-                    r = self._functions[func_name](*args,**kwargs)
+                    r = self._functions[func_name](*args, **kwargs)
                     connection.send(pickle.dumps(r))
                 except Exception as e:
                     connection.send(pickle.dumps(e))
         except EOFError:
-            pass
+            pass
 
 
 def rpc_server(hdlr, address, authkey):
@@ -44,11 +44,17 @@ def rpc_server(hdlr, address, authkey):
 models = []
 tokenizer = None
 
+
 def chat(messages, gen_conf):
     global tokenizer
     model = Model()
     try:
-        conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))}
+        conf = {
+            "max_new_tokens": int(
+                gen_conf.get(
+                    "max_tokens", 256)), "temperature": float(
+                gen_conf.get(
+                    "temperature", 0.1))}
         print(messages, conf)
         text = tokenizer.apply_chat_template(
             messages,
@@ -65,7 +71,8 @@ def chat(messages, gen_conf):
             output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
         ]
 
-        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return tokenizer.batch_decode(
+            generated_ids, skip_special_tokens=True)[0]
     except Exception as e:
         return str(e)
 
@@ -75,10 +82,15 @@ def Model():
     random.seed(time.time())
     return random.choice(models)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name", type=str, help="Model name")
-    parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
+    parser.add_argument(
+        "--port",
+        default=7860,
+        type=int,
+        help="RPC serving port")
     args = parser.parse_args()
 
     handler = RPCHandler()
@@ -93,4 +105,5 @@ if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 
     # Run the server
-    rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
+    rpc_server(handler, ('0.0.0.0', args.port),
+               authkey=b'infiniflow-token4kevinhu')
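For context on how this server is called: the handler unpickles a (func_name, args, kwargs) tuple from connection.recv() and sends back a pickled result or exception, and the server listens on a port with an authkey, which fits the multiprocessing.connection pattern. A minimal client sketch under those assumptions follows; the RPCProxy class is illustrative, not part of this commit, and it assumes chat is registered with the handler, which the truncated diff does not show.

# Hedged client sketch; assumes the server side uses
# multiprocessing.connection.Listener, matching the recv()/send(),
# pickle and authkey usage shown in the diff above.
import pickle
from multiprocessing.connection import Client


class RPCProxy:
    def __init__(self, address, authkey):
        self._connection = Client(address, authkey=authkey)

    def __getattr__(self, name):
        def do_rpc(*args, **kwargs):
            # Mirror the server: send a pickled (func_name, args, kwargs)
            # tuple, then unpickle either the result or an exception.
            self._connection.send(pickle.dumps((name, args, kwargs)))
            result = pickle.loads(self._connection.recv())
            if isinstance(result, Exception):
                raise result
            return result
        return do_rpc


if __name__ == "__main__":
    proxy = RPCProxy(("127.0.0.1", 7860),
                     authkey=b'infiniflow-token4kevinhu')
    print(proxy.chat([{"role": "user", "content": "Hello"}],
                     {"max_tokens": 256, "temperature": 0.1}))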