refactor: 新增用户脚本服务器代码

This commit is contained in:
Quan 2025-11-08 21:30:18 +08:00
parent f72b163eb8
commit 1e5ffc1907
11 changed files with 148 additions and 19 deletions

View File

@ -22,6 +22,7 @@ dependencies = [
"rookiepy>=0.5.6",
"textual>=6.5.0",
"uvicorn>=0.38.0",
"websockets>=15.0.1",
]
[project.urls]

View File

@ -26,3 +26,5 @@ textual==6.5.0
# via xhs-downloader (pyproject.toml)
uvicorn==0.38.0
# via xhs-downloader (pyproject.toml)
websockets==15.0.1
# via xhs-downloader (pyproject.toml)

View File

@ -38,6 +38,7 @@ class XHSDownloader(App):
**self.parameter,
_print=False,
)
self.APP.init_script_server()
async def on_mount(self) -> None:
self.theme = "nord"
@ -76,6 +77,7 @@ class XHSDownloader(App):
await self.APP.close()
self.__initialization()
await self.__aenter__()
await self.APP.switch_script_server()
self.uninstall_screen("index")
self.uninstall_screen("setting")
self.uninstall_screen("loading")

View File

@ -162,6 +162,15 @@ class Setting(Screen):
),
classes="horizontal-layout",
),
Label(),
Container(
Checkbox(
_("脚本服务器开关"),
id="script_server",
value=self.data["script_server"],
),
classes="horizontal-layout",
),
Container(
Label(
_("图片下载格式"),
@ -235,6 +244,7 @@ class Setting(Screen):
"download_record": self.query_one("#download_record").value,
"author_archive": self.query_one("#author_archive").value,
"write_mtime": self.query_one("#write_mtime").value,
"script_server": self.query_one("#script_server").value,
}
)

View File

@ -1,4 +1,13 @@
from asyncio import Event, Queue, QueueEmpty, create_task, gather, sleep
from asyncio import (
Event,
Queue,
QueueEmpty,
create_task,
gather,
sleep,
Future,
CancelledError,
)
from contextlib import suppress
from datetime import datetime
from re import compile
@ -14,14 +23,14 @@ from pydantic import Field
from pyperclip import copy, paste
from uvicorn import Config, Server
from source.expansion import (
from ..expansion import (
BrowserCookie,
Cleaner,
Converter,
Namespace,
beautify_string,
)
from source.module import (
from ..module import (
__VERSION__,
ERROR,
MASTER,
@ -39,8 +48,9 @@ from source.module import (
MapRecorder,
logging,
# sleep_time,
ScriptServer,
)
from source.translation import _, switch_language
from ..translation import _, switch_language
from ..module import Mapping
from .download import Download
@ -111,6 +121,7 @@ class XHS:
write_mtime=False,
language="zh_CN",
read_cookie: int | str = None,
script_server: bool = False,
_print: bool = True,
*args,
**kwargs,
@ -136,6 +147,7 @@ class XHS:
folder_mode,
author_archive,
write_mtime,
script_server,
_print,
self.CLEANER,
)
@ -155,6 +167,7 @@ class XHS:
self.clipboard_cache: str = ""
self.queue = Queue()
self.event = Event()
self.script = None
def __extract_image(self, container: dict, data: Namespace):
container["下载地址"], container["动图地址"] = self.image.get_image_link(
@ -474,6 +487,7 @@ class XHS:
await self.id_recorder.__aexit__(exc_type, exc_value, traceback)
await self.data_recorder.__aexit__(exc_type, exc_value, traceback)
await self.map_recorder.__aexit__(exc_type, exc_value, traceback)
await self.stop_script_server()
await self.close()
async def close(self):
@ -796,3 +810,46 @@ class XHS:
else:
msg = _("获取小红书作品数据失败")
return msg, data
def init_script_server(
self,
):
if self.manager.script_server:
self.run_script_server()
async def switch_script_server(
self,
switch: bool = None,
):
if switch is None:
switch = self.manager.script_server
if switch:
self.run_script_server()
else:
await self.stop_script_server()
def run_script_server(
self,
host="0.0.0.0",
port=5556,
):
if not self.script:
self.script = create_task(self._run_script_server(host, port))
async def _run_script_server(
self,
host="0.0.0.0",
port=5556,
):
async with ScriptServer(self, host, port):
await Future()
async def stop_script_server(self):
if self.script:
self.script.cancel()
with suppress(CancelledError):
await self.script
self.script = None
async def _script_server_debug(self):
await self.switch_script_server(self.manager.script_server)

View File

@ -39,3 +39,4 @@ from .tools import (
sleep_time,
retry_limited,
)
from .script import ScriptServer

View File

@ -70,6 +70,7 @@ class Manager:
folder_mode: bool,
author_archive: bool,
write_mtime: bool,
script_server: bool,
_print: bool,
cleaner: "Cleaner",
):
@ -126,6 +127,7 @@ class Manager:
self.live_download = self.check_bool(live_download, True)
self.author_archive = self.check_bool(author_archive, False)
self.write_mtime = self.check_bool(write_mtime, False)
self.script_server = self.check_bool(script_server, False)
self.create_folder()
def __check_path(self, path: str) -> Path:
@ -282,8 +284,12 @@ class Manager:
self.folder.mkdir(exist_ok=True)
self.temp.mkdir(exist_ok=True)
def compatible(self,):
if self.path == self.root and (
old := self.path.parent.joinpath(self.folder.name)
).exists() and not self.folder.exists():
def compatible(
self,
):
if (
self.path == self.root
and (old := self.path.parent.joinpath(self.folder.name)).exists()
and not self.folder.exists()
):
move(old, self.folder)

47
source/module/script.py Normal file
View File

@ -0,0 +1,47 @@
from contextlib import suppress
from websockets import ConnectionClosed, serve
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..application import XHS
class ScriptServer:
def __init__(
self,
core: "XHS",
host="0.0.0.0",
port=5556,
):
self.core = core
self.host = host
self.port = port
self.server = None
async def handler(self, websocket):
with suppress(ConnectionClosed):
async for message in websocket:
print(f"收到消息: {message}")
await websocket.send("消息已接收")
async def start(self):
"""启动服务器"""
self.server = await serve(
self.handler,
self.host,
self.port,
)
async def stop(self):
"""停止服务器"""
if self.server:
self.server.close()
await self.server.wait_closed()
async def __aenter__(self):
await self.start()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.stop()

View File

@ -30,6 +30,7 @@ class Settings:
"author_archive": False, # 是否按作者归档
"write_mtime": False, # 是否写入修改时间
"language": "zh_CN", # 语言设置
"script_server": False, # 是否启用脚本服务器
}
# 根据操作系统设置编码格式
encode = "UTF-8-SIG" if system() == "Windows" else "UTF-8"

View File

@ -25,13 +25,13 @@ HEADERS = {
"user-agent": USERAGENT,
}
MASTER = "b #fff200"
PROMPT = "b turquoise2"
GENERAL = "b bright_white"
PROGRESS = "b bright_magenta"
ERROR = "b bright_red"
WARNING = "b bright_yellow"
INFO = "b bright_green"
MASTER = "#fff200"
PROMPT = "turquoise2"
GENERAL = "bright_white"
PROGRESS = "bright_magenta"
ERROR = "bright_red"
WARNING = "bright_yellow"
INFO = "bright_green"
FILE_SIGNATURES: tuple[
tuple[

10
uv.lock generated
View File

@ -1327,6 +1327,7 @@ dependencies = [
{ name = "rookiepy" },
{ name = "textual" },
{ name = "uvicorn" },
{ name = "websockets" },
]
[package.dev-dependencies]
@ -1340,15 +1341,16 @@ requires-dist = [
{ name = "aiosqlite", specifier = ">=0.21.0" },
{ name = "click", specifier = ">=8.3.0" },
{ name = "emoji", specifier = ">=2.15.0" },
{ name = "fastapi", specifier = ">=0.119.0" },
{ name = "fastmcp", specifier = ">=2.12.4" },
{ name = "fastapi", specifier = ">=0.121.0" },
{ name = "fastmcp", specifier = ">=2.13.0" },
{ name = "httpx", extras = ["socks"], specifier = ">=0.28.1" },
{ name = "lxml", specifier = ">=6.0.2" },
{ name = "pyperclip", specifier = ">=1.11.0" },
{ name = "pyyaml", specifier = ">=6.0.3" },
{ name = "rookiepy", specifier = ">=0.5.6" },
{ name = "textual", specifier = ">=6.3.0" },
{ name = "uvicorn", specifier = ">=0.37.0" },
{ name = "textual", specifier = ">=6.5.0" },
{ name = "uvicorn", specifier = ">=0.38.0" },
{ name = "websockets", specifier = ">=15.0.1" },
]
[package.metadata.requires-dev]