Add more commands to RAGFlow CLI (#12731)

### What problem does this PR solve?

This PR is going to make RAGFlow CLI to access RAGFlow as normal user,
and work as the a testing tool for RAGFlow server.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2026-01-21 18:49:52 +08:00
committed by GitHub
parent 6cd4fd91e6
commit 2e2c8f6ca9
8 changed files with 2387 additions and 1173 deletions

161
admin/client/http_client.py Normal file
View File

@ -0,0 +1,161 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import json
from typing import Any, Dict, Optional, Tuple
import requests
class HttpClient:
def __init__(
self,
host: str = "127.0.0.1",
port: int = 9381,
api_version: str = "v1",
api_key: Optional[str] = None,
connect_timeout: float = 5.0,
read_timeout: float = 60.0,
verify_ssl: bool = False,
) -> None:
self.host = host
self.port = port
self.api_version = api_version
self.api_key = api_key
self.login_token: str | None = None
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
self.verify_ssl = verify_ssl
def api_base(self) -> str:
return f"{self.host}:{self.port}/api/{self.api_version}"
def non_api_base(self) -> str:
return f"{self.host}:{self.port}/{self.api_version}"
def build_url(self, path: str, use_api_base: bool = True) -> str:
base = self.api_base() if use_api_base else self.non_api_base()
if self.verify_ssl:
return f"https://{base}/{path.lstrip('/')}"
else:
return f"http://{base}/{path.lstrip('/')}"
def _headers(self, auth_kind: Optional[str], extra: Optional[Dict[str, str]]) -> Dict[str, str]:
headers = {}
if auth_kind == "api" and self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
elif auth_kind == "web" and self.login_token:
headers["Authorization"] = self.login_token
elif auth_kind == "admin" and self.login_token:
headers["Authorization"] = self.login_token
else:
pass
if extra:
headers.update(extra)
return headers
def request(
self,
method: str,
path: str,
*,
use_api_base: bool = True,
auth_kind: Optional[str] = "api",
headers: Optional[Dict[str, str]] = None,
json_body: Optional[Dict[str, Any]] = None,
data: Any = None,
files: Any = None,
params: Optional[Dict[str, Any]] = None,
stream: bool = False,
iterations: int = 1,
) -> requests.Response | dict:
url = self.build_url(path, use_api_base=use_api_base)
merged_headers = self._headers(auth_kind, headers)
timeout: Tuple[float, float] = (self.connect_timeout, self.read_timeout)
if iterations > 1:
response_list = []
total_duration = 0.0
for _ in range(iterations):
start_time = time.perf_counter()
response = requests.request(
method=method,
url=url,
headers=merged_headers,
json=json_body,
data=data,
files=files,
params=params,
timeout=timeout,
stream=stream,
verify=self.verify_ssl,
)
end_time = time.perf_counter()
total_duration += end_time - start_time
response_list.append(response)
return {"duration": total_duration, "response_list": response_list}
else:
return requests.request(
method=method,
url=url,
headers=merged_headers,
json=json_body,
data=data,
files=files,
params=params,
timeout=timeout,
stream=stream,
verify=self.verify_ssl,
)
def request_json(
self,
method: str,
path: str,
*,
use_api_base: bool = True,
auth_kind: Optional[str] = "api",
headers: Optional[Dict[str, str]] = None,
json_body: Optional[Dict[str, Any]] = None,
data: Any = None,
files: Any = None,
params: Optional[Dict[str, Any]] = None,
stream: bool = False,
) -> Dict[str, Any]:
response = self.request(
method,
path,
use_api_base=use_api_base,
auth_kind=auth_kind,
headers=headers,
json_body=json_body,
data=data,
files=files,
params=params,
stream=stream,
)
try:
return response.json()
except Exception as exc:
raise ValueError(f"Non-JSON response from {path}: {exc}") from exc
@staticmethod
def parse_json_bytes(raw: bytes) -> Dict[str, Any]:
try:
return json.loads(raw.decode("utf-8"))
except Exception as exc:
raise ValueError(f"Invalid JSON payload: {exc}") from exc

609
admin/client/parser.py Normal file
View File

@ -0,0 +1,609 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from lark import Transformer
GRAMMAR = r"""
start: command
command: sql_command | meta_command
sql_command: list_services
| show_service
| startup_service
| shutdown_service
| restart_service
| register_user
| list_users
| show_user
| drop_user
| alter_user
| create_user
| activate_user
| list_datasets
| list_agents
| create_role
| drop_role
| alter_role
| list_roles
| show_role
| grant_permission
| revoke_permission
| alter_user_role
| show_user_permission
| show_version
| grant_admin
| revoke_admin
| set_variable
| show_variable
| list_variables
| list_configs
| list_environments
| generate_key
| list_keys
| drop_key
| show_current_user
| set_default_llm
| set_default_vlm
| set_default_embedding
| set_default_reranker
| set_default_asr
| set_default_tts
| reset_default_llm
| reset_default_vlm
| reset_default_embedding
| reset_default_reranker
| reset_default_asr
| reset_default_tts
| create_model_provider
| drop_model_provider
| create_user_dataset_with_parser
| create_user_dataset_with_pipeline
| drop_user_dataset
| list_user_datasets
| list_user_dataset_files
| list_user_agents
| list_user_chats
| create_user_chat
| drop_user_chat
| list_user_model_providers
| list_user_default_models
| parse_dataset_docs
| parse_dataset_sync
| parse_dataset_async
| import_docs_into_dataset
| search_on_datasets
| benchmark
// meta command definition
meta_command: "\\" meta_command_name [meta_args]
meta_command_name: /[a-zA-Z?]+/
meta_args: (meta_arg)+
meta_arg: /[^\\s"']+/ | quoted_string
// command definition
REGISTER: "REGISTER"i
LIST: "LIST"i
SERVICES: "SERVICES"i
SHOW: "SHOW"i
CREATE: "CREATE"i
SERVICE: "SERVICE"i
SHUTDOWN: "SHUTDOWN"i
STARTUP: "STARTUP"i
RESTART: "RESTART"i
USERS: "USERS"i
DROP: "DROP"i
USER: "USER"i
ALTER: "ALTER"i
ACTIVE: "ACTIVE"i
ADMIN: "ADMIN"i
PASSWORD: "PASSWORD"i
DATASET: "DATASET"i
DATASETS: "DATASETS"i
OF: "OF"i
AGENTS: "AGENTS"i
ROLE: "ROLE"i
ROLES: "ROLES"i
DESCRIPTION: "DESCRIPTION"i
GRANT: "GRANT"i
REVOKE: "REVOKE"i
ALL: "ALL"i
PERMISSION: "PERMISSION"i
TO: "TO"i
FROM: "FROM"i
FOR: "FOR"i
RESOURCES: "RESOURCES"i
ON: "ON"i
SET: "SET"i
RESET: "RESET"i
VERSION: "VERSION"i
VAR: "VAR"i
VARS: "VARS"i
CONFIGS: "CONFIGS"i
ENVS: "ENVS"i
KEY: "KEY"i
KEYS: "KEYS"i
GENERATE: "GENERATE"i
MODEL: "MODEL"i
MODELS: "MODELS"i
PROVIDER: "PROVIDER"i
PROVIDERS: "PROVIDERS"i
DEFAULT: "DEFAULT"i
CHATS: "CHATS"i
CHAT: "CHAT"i
FILES: "FILES"i
AS: "AS"i
PARSE: "PARSE"i
IMPORT: "IMPORT"i
INTO: "INTO"i
WITH: "WITH"i
PARSER: "PARSER"i
PIPELINE: "PIPELINE"i
SEARCH: "SEARCH"i
CURRENT: "CURRENT"i
LLM: "LLM"i
VLM: "VLM"i
EMBEDDING: "EMBEDDING"i
RERANKER: "RERANKER"i
ASR: "ASR"i
TTS: "TTS"i
ASYNC: "ASYNC"i
SYNC: "SYNC"i
BENCHMARK: "BENCHMARK"i
list_services: LIST SERVICES ";"
show_service: SHOW SERVICE NUMBER ";"
startup_service: STARTUP SERVICE NUMBER ";"
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
restart_service: RESTART SERVICE NUMBER ";"
register_user: REGISTER USER quoted_string AS quoted_string PASSWORD quoted_string ";"
list_users: LIST USERS ";"
drop_user: DROP USER quoted_string ";"
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
show_user: SHOW USER quoted_string ";"
create_user: CREATE USER quoted_string quoted_string ";"
activate_user: ALTER USER ACTIVE quoted_string status ";"
list_datasets: LIST DATASETS OF quoted_string ";"
list_agents: LIST AGENTS OF quoted_string ";"
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
drop_role: DROP ROLE identifier ";"
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
list_roles: LIST ROLES ";"
show_role: SHOW ROLE identifier ";"
grant_permission: GRANT identifier_list ON identifier TO ROLE identifier ";"
revoke_permission: REVOKE identifier_list ON identifier FROM ROLE identifier ";"
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
show_user_permission: SHOW USER PERMISSION quoted_string ";"
show_version: SHOW VERSION ";"
grant_admin: GRANT ADMIN quoted_string ";"
revoke_admin: REVOKE ADMIN quoted_string ";"
generate_key: GENERATE KEY FOR USER quoted_string ";"
list_keys: LIST KEYS OF quoted_string ";"
drop_key: DROP KEY quoted_string OF quoted_string ";"
set_variable: SET VAR identifier identifier ";"
show_variable: SHOW VAR identifier ";"
list_variables: LIST VARS ";"
list_configs: LIST CONFIGS ";"
list_environments: LIST ENVS ";"
benchmark: BENCHMARK NUMBER NUMBER user_statement
user_statement: show_current_user
| create_model_provider
| drop_model_provider
| set_default_llm
| set_default_vlm
| set_default_embedding
| set_default_reranker
| set_default_asr
| set_default_tts
| reset_default_llm
| reset_default_vlm
| reset_default_embedding
| reset_default_reranker
| reset_default_asr
| reset_default_tts
| create_user_dataset_with_parser
| create_user_dataset_with_pipeline
| drop_user_dataset
| list_user_datasets
| list_user_dataset_files
| list_user_agents
| list_user_chats
| create_user_chat
| drop_user_chat
| list_user_model_providers
| list_user_default_models
| import_docs_into_dataset
| search_on_datasets
show_current_user: SHOW CURRENT USER ";"
create_model_provider: CREATE MODEL PROVIDER quoted_string quoted_string ";"
drop_model_provider: DROP MODEL PROVIDER quoted_string ";"
set_default_llm: SET DEFAULT LLM quoted_string ";"
set_default_vlm: SET DEFAULT VLM quoted_string ";"
set_default_embedding: SET DEFAULT EMBEDDING quoted_string ";"
set_default_reranker: SET DEFAULT RERANKER quoted_string ";"
set_default_asr: SET DEFAULT ASR quoted_string ";"
set_default_tts: SET DEFAULT TTS quoted_string ";"
reset_default_llm: RESET DEFAULT LLM ";"
reset_default_vlm: RESET DEFAULT VLM ";"
reset_default_embedding: RESET DEFAULT EMBEDDING ";"
reset_default_reranker: RESET DEFAULT RERANKER ";"
reset_default_asr: RESET DEFAULT ASR ";"
reset_default_tts: RESET DEFAULT TTS ";"
list_user_datasets: LIST DATASETS ";"
create_user_dataset_with_parser: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PARSER quoted_string ";"
create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PIPELINE quoted_string ";"
drop_user_dataset: DROP DATASET quoted_string ";"
list_user_dataset_files: LIST FILES OF DATASET quoted_string ";"
list_user_agents: LIST AGENTS ";"
list_user_chats: LIST CHATS ";"
create_user_chat: CREATE CHAT quoted_string ";"
drop_user_chat: DROP CHAT quoted_string ";"
list_user_model_providers: LIST MODEL PROVIDERS ";"
list_user_default_models: LIST DEFAULT MODELS ";"
import_docs_into_dataset: IMPORT quoted_string INTO DATASET quoted_string ";"
search_on_datasets: SEARCH quoted_string ON DATASETS quoted_string ";"
parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";"
parse_dataset_sync: PARSE DATASET quoted_string SYNC ";"
parse_dataset_async: PARSE DATASET quoted_string ASYNC ";"
identifier_list: identifier ("," identifier)*
identifier: WORD
quoted_string: QUOTED_STRING
status: WORD
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
WORD: /[a-zA-Z0-9_\-\.]+/
NUMBER: /[0-9]+/
%import common.WS
%ignore WS
"""
class RAGFlowCLITransformer(Transformer):
def start(self, items):
return items[0]
def command(self, items):
return items[0]
def list_services(self, items):
result = {"type": "list_services"}
return result
def show_service(self, items):
service_id = int(items[2])
return {"type": "show_service", "number": service_id}
def startup_service(self, items):
service_id = int(items[2])
return {"type": "startup_service", "number": service_id}
def shutdown_service(self, items):
service_id = int(items[2])
return {"type": "shutdown_service", "number": service_id}
def restart_service(self, items):
service_id = int(items[2])
return {"type": "restart_service", "number": service_id}
def register_user(self, items):
user_name: str = items[2].children[0].strip("'\"")
nickname: str = items[4].children[0].strip("'\"")
password: str = items[6].children[0].strip("'\"")
return {"type": "register_user", "user_name": user_name, "nickname": nickname, "password": password}
def list_users(self, items):
return {"type": "list_users"}
def show_user(self, items):
user_name = items[2]
return {"type": "show_user", "user_name": user_name}
def drop_user(self, items):
user_name = items[2]
return {"type": "drop_user", "user_name": user_name}
def alter_user(self, items):
user_name = items[3]
new_password = items[4]
return {"type": "alter_user", "user_name": user_name, "password": new_password}
def create_user(self, items):
user_name = items[2]
password = items[3]
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
def activate_user(self, items):
user_name = items[3]
activate_status = items[4]
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
def list_datasets(self, items):
user_name = items[3]
return {"type": "list_datasets", "user_name": user_name}
def list_agents(self, items):
user_name = items[3]
return {"type": "list_agents", "user_name": user_name}
def create_role(self, items):
role_name = items[2]
if len(items) > 4:
description = items[4]
return {"type": "create_role", "role_name": role_name, "description": description}
else:
return {"type": "create_role", "role_name": role_name}
def drop_role(self, items):
role_name = items[2]
return {"type": "drop_role", "role_name": role_name}
def alter_role(self, items):
role_name = items[2]
description = items[5]
return {"type": "alter_role", "role_name": role_name, "description": description}
def list_roles(self, items):
return {"type": "list_roles"}
def show_role(self, items):
role_name = items[2]
return {"type": "show_role", "role_name": role_name}
def grant_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def revoke_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def alter_user_role(self, items):
user_name = items[2]
role_name = items[5]
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
def show_user_permission(self, items):
user_name = items[3]
return {"type": "show_user_permission", "user_name": user_name}
def show_version(self, items):
return {"type": "show_version"}
def grant_admin(self, items):
user_name = items[2]
return {"type": "grant_admin", "user_name": user_name}
def revoke_admin(self, items):
user_name = items[2]
return {"type": "revoke_admin", "user_name": user_name}
def generate_key(self, items):
user_name = items[4]
return {"type": "generate_key", "user_name": user_name}
def list_keys(self, items):
user_name = items[3]
return {"type": "list_keys", "user_name": user_name}
def drop_key(self, items):
key = items[2]
user_name = items[4]
return {"type": "drop_key", "key": key, "user_name": user_name}
def set_variable(self, items):
var_name = items[2]
var_value = items[3]
return {"type": "set_variable", "var_name": var_name, "var_value": var_value}
def show_variable(self, items):
var_name = items[2]
return {"type": "show_variable", "var_name": var_name}
def list_variables(self, items):
return {"type": "list_variables"}
def list_configs(self, items):
return {"type": "list_configs"}
def list_environments(self, items):
return {"type": "list_environments"}
def create_model_provider(self, items):
provider_name = items[3].children[0].strip("'\"")
provider_key = items[4].children[0].strip("'\"")
return {"type": "create_model_provider", "provider_name": provider_name, "provider_key": provider_key}
def drop_model_provider(self, items):
provider_name = items[3].children[0].strip("'\"")
return {"type": "drop_model_provider", "provider_name": provider_name}
def show_current_user(self, items):
return {"type": "show_current_user"}
def set_default_llm(self, items):
llm_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "llm_id", "model_id": llm_id}
def set_default_vlm(self, items):
vlm_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "img2txt_id", "model_id": vlm_id}
def set_default_embedding(self, items):
embedding_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "embd_id", "model_id": embedding_id}
def set_default_reranker(self, items):
reranker_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "reranker_id", "model_id": reranker_id}
def set_default_asr(self, items):
asr_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "asr_id", "model_id": asr_id}
def set_default_tts(self, items):
tts_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "tts_id", "model_id": tts_id}
def reset_default_llm(self, items):
return {"type": "reset_default_model", "model_type": "llm_id"}
def reset_default_vlm(self, items):
return {"type": "reset_default_model", "model_type": "img2txt_id"}
def reset_default_embedding(self, items):
return {"type": "reset_default_model", "model_type": "embd_id"}
def reset_default_reranker(self, items):
return {"type": "reset_default_model", "model_type": "reranker_id"}
def reset_default_asr(self, items):
return {"type": "reset_default_model", "model_type": "asr_id"}
def reset_default_tts(self, items):
return {"type": "reset_default_model", "model_type": "tts_id"}
def list_user_datasets(self, items):
return {"type": "list_user_datasets"}
def create_user_dataset_with_parser(self, items):
dataset_name = items[2].children[0].strip("'\"")
embedding = items[5].children[0].strip("'\"")
parser_type = items[7].children[0].strip("'\"")
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
"parser_type": parser_type}
def create_user_dataset_with_pipeline(self, items):
dataset_name = items[2].children[0].strip("'\"")
embedding = items[5].children[0].strip("'\"")
pipeline = items[7].children[0].strip("'\"")
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
"pipeline": pipeline}
def drop_user_dataset(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "drop_user_dataset", "dataset_name": dataset_name}
def list_user_dataset_files(self, items):
dataset_name = items[4].children[0].strip("'\"")
return {"type": "list_user_dataset_files", "dataset_name": dataset_name}
def list_user_agents(self, items):
return {"type": "list_user_agents"}
def list_user_chats(self, items):
return {"type": "list_user_chats"}
def create_user_chat(self, items):
chat_name = items[2].children[0].strip("'\"")
return {"type": "create_user_chat", "chat_name": chat_name}
def drop_user_chat(self, items):
chat_name = items[2].children[0].strip("'\"")
return {"type": "drop_user_chat", "chat_name": chat_name}
def list_user_model_providers(self, items):
return {"type": "list_user_model_providers"}
def list_user_default_models(self, items):
return {"type": "list_user_default_models"}
def parse_dataset_docs(self, items):
document_list_str = items[1].children[0].strip("'\"")
document_names = document_list_str.split(",")
if len(document_names) == 1:
document_names = document_names[0]
document_names = document_names.split(" ")
dataset_name = items[4].children[0].strip("'\"")
return {"type": "parse_dataset_docs", "dataset_name": dataset_name, "document_names": document_names}
def parse_dataset_sync(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "sync"}
def parse_dataset_async(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "async"}
def import_docs_into_dataset(self, items):
document_list_str = items[1].children[0].strip("'\"")
document_paths = document_list_str.split(",")
if len(document_paths) == 1:
document_paths = document_paths[0]
document_paths = document_paths.split(" ")
dataset_name = items[4].children[0].strip("'\"")
return {"type": "import_docs_into_dataset", "dataset_name": dataset_name, "document_paths": document_paths}
def search_on_datasets(self, items):
question = items[1].children[0].strip("'\"")
datasets_str = items[4].children[0].strip("'\"")
datasets = datasets_str.split(",")
if len(datasets) == 1:
datasets = datasets[0]
datasets = datasets.split(" ")
return {"type": "search_on_datasets", "datasets": datasets, "question": question}
def benchmark(self, items):
concurrency: int = int(items[1])
iterations: int = int(items[2])
command = items[3].children[0]
return {"type": "benchmark", "concurrency": concurrency, "iterations": iterations, "command": command}
def action_list(self, items):
return items
def meta_command(self, items):
command_name = str(items[0]).lower()
args = items[1:] if len(items) > 1 else []
# handle quoted parameter
parsed_args = []
for arg in args:
if hasattr(arg, "value"):
parsed_args.append(arg.value)
else:
parsed_args.append(str(arg))
return {"type": "meta", "command": command_name, "args": parsed_args}
def meta_command_name(self, items):
return items[0]
def meta_args(self, items):
return items

View File

@ -20,5 +20,8 @@ test = [
"requests-toolbelt>=1.0.0",
]
[tool.setuptools]
py-modules = ["ragflow_cli", "parser"]
[project.scripts]
ragflow-cli = "ragflow_cli:main"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

65
admin/client/user.py Normal file
View File

@ -0,0 +1,65 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from http_client import HttpClient
class AuthException(Exception):
def __init__(self, message, code=401):
super().__init__(message)
self.code = code
self.message = message
def encrypt_password(password_plain: str) -> str:
try:
from api.utils.crypt import crypt
except Exception as exc:
raise AuthException(
"Password encryption unavailable; install pycryptodomex (uv sync --python 3.12 --group test)."
) from exc
return crypt(password_plain)
def register_user(client: HttpClient, email: str, nickname: str, password: str) -> None:
password_enc = encrypt_password(password)
payload = {"email": email, "nickname": nickname, "password": password_enc}
res = client.request_json("POST", "/user/register", use_api_base=False, auth_kind=None, json_body=payload)
if res.get("code") == 0:
return
msg = res.get("message", "")
if "has already registered" in msg:
return
raise AuthException(f"Register failed: {msg}")
def login_user(client: HttpClient, server_type: str, email: str, password: str) -> str:
password_enc = encrypt_password(password)
payload = {"email": email, "password": password_enc}
if server_type == "admin":
response = client.request("POST", "/admin/login", use_api_base=True, auth_kind=None, json_body=payload)
else:
response = client.request("POST", "/user/login", use_api_base=False, auth_kind=None, json_body=payload)
try:
res = response.json()
except Exception as exc:
raise AuthException(f"Login failed: invalid JSON response ({exc})") from exc
if res.get("code") != 0:
raise AuthException(f"Login failed: {res.get('message')}")
token = response.headers.get("Authorization")
if not token:
raise AuthException("Login failed: missing Authorization header")
return token