apply pep8 formalize (#155)

This commit is contained in:
KevinHuSh
2024-03-27 11:33:46 +08:00
committed by GitHub
parent a02e836790
commit fd7fcb5baf
55 changed files with 1568 additions and 753 deletions

View File

@ -9,7 +9,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
class RPCHandler:
def __init__(self):
self._functions = { }
self._functions = {}
def register_function(self, func):
self._functions[func.__name__] = func
@ -21,12 +21,12 @@ class RPCHandler:
func_name, args, kwargs = pickle.loads(connection.recv())
# Run the RPC and send a response
try:
r = self._functions[func_name](*args,**kwargs)
r = self._functions[func_name](*args, **kwargs)
connection.send(pickle.dumps(r))
except Exception as e:
connection.send(pickle.dumps(e))
except EOFError:
pass
pass
def rpc_server(hdlr, address, authkey):
@ -44,11 +44,17 @@ def rpc_server(hdlr, address, authkey):
models = []
tokenizer = None
def chat(messages, gen_conf):
global tokenizer
model = Model()
try:
conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))}
conf = {
"max_new_tokens": int(
gen_conf.get(
"max_tokens", 256)), "temperature": float(
gen_conf.get(
"temperature", 0.1))}
print(messages, conf)
text = tokenizer.apply_chat_template(
messages,
@ -65,7 +71,8 @@ def chat(messages, gen_conf):
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return tokenizer.batch_decode(
generated_ids, skip_special_tokens=True)[0]
except Exception as e:
return str(e)
@ -75,10 +82,15 @@ def Model():
random.seed(time.time())
return random.choice(models)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, help="Model name")
parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
parser.add_argument(
"--port",
default=7860,
type=int,
help="RPC serving port")
args = parser.parse_args()
handler = RPCHandler()
@ -93,4 +105,5 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Run the server
rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
rpc_server(handler, ('0.0.0.0', args.port),
authkey=b'infiniflow-token4kevinhu')