refactor: 重构运行提示传递逻辑

This commit is contained in:
Quan
2025-12-16 12:55:52 +08:00
parent 8039c6a5db
commit 855113c94d
11 changed files with 117 additions and 181 deletions

View File

@@ -21,6 +21,7 @@ from pydantic import Field
from types import SimpleNamespace
from pyperclip import copy, paste
from uvicorn import Config, Server
from typing import Callable
from ..expansion import (
BrowserCookie,
@@ -48,6 +49,7 @@ from ..module import (
logging,
# sleep_time,
ScriptServer,
INFO,
)
from ..translation import _, switch_language
@@ -57,6 +59,7 @@ from .explore import Explore
from .image import Image
from .request import Html
from .video import Video
from rich import print
__all__ = ["XHS"]
@@ -79,6 +82,19 @@ def data_cache(function):
return inner
class Print:
    """Carrier for the output callable used by the logging helper.

    The instance does not invoke the callable; calling the instance
    simply hands back the stored function so the consumer can use it.
    """

    def __init__(self, func: Callable = print):
        # Default output target is the module's ``print`` function.
        self.func = func

    def __call__(self):
        # Return the wrapped callable itself — the caller invokes it.
        return self.func
class XHS:
VERSION_MAJOR = VERSION_MAJOR
VERSION_MINOR = VERSION_MINOR
@@ -123,11 +139,11 @@ class XHS:
script_server: bool = False,
script_host="0.0.0.0",
script_port=5558,
_print: bool = True,
*args,
**kwargs,
):
switch_language(language)
self.print = Print()
self.manager = Manager(
ROOT,
work_path,
@@ -149,8 +165,8 @@ class XHS:
author_archive,
write_mtime,
script_server,
_print,
self.CLEANER,
self.print,
)
self.mapping_data = mapping_data or {}
self.map_recorder = MapRecorder(
@@ -190,14 +206,12 @@ class XHS:
container: dict,
download: bool,
index,
log,
bar,
count: SimpleNamespace,
):
name = self.__naming_rules(container)
if (u := container["下载地址"]) and download:
if await self.skip_download(i := container["作品ID"]):
logging(log, _("作品 {0} 存在下载记录,跳过下载").format(i))
self.logging(_("作品 {0} 存在下载记录,跳过下载").format(i))
count.skip += 1
else:
__, result = await self.download.run(
@@ -210,8 +224,6 @@ class XHS:
name,
container["作品类型"],
container["时间戳"],
log,
bar,
)
if result:
count.success += 1
@@ -221,7 +233,7 @@ class XHS:
else:
count.fail += 1
elif not u:
logging(log, _("提取作品文件下载地址失败"), ERROR)
self.logging(_("提取作品文件下载地址失败"), ERROR)
count.fail += 1
await self.save_data(container)
@@ -247,12 +259,14 @@ class XHS:
url: str,
download=False,
index: list | tuple = None,
log=None,
bar=None,
data=True,
) -> list[dict]:
if not (urls := await self.extract_links(url, log)):
logging(log, _("提取小红书作品链接失败"), WARNING)
if not (
urls := await self.extract_links(
url,
)
):
self.logging(_("提取小红书作品链接失败"), WARNING)
return []
statistics = SimpleNamespace(
all=len(urls),
@@ -260,14 +274,12 @@ class XHS:
fail=0,
skip=0,
)
logging(log, _("{0} 个小红书作品待处理...").format(statistics.all))
self.logging(_("{0} 个小红书作品待处理...").format(statistics.all))
result = [
await self.__deal_extract(
i,
download,
index,
log,
bar,
data,
count=statistics,
)
@@ -275,17 +287,14 @@ class XHS:
]
self.show_statistics(
statistics,
log,
)
return result
@staticmethod
def show_statistics(
self,
statistics: SimpleNamespace,
log=None,
) -> None:
logging(
log,
self.logging(
_("共处理 {0} 个作品,成功 {1} 个,失败 {2} 个,跳过 {3}").format(
statistics.all,
statistics.success,
@@ -299,21 +308,19 @@ class XHS:
url: str,
download=True,
index: list | tuple = None,
log=None,
bar=None,
data=False,
) -> None:
url = await self.extract_links(url, log)
url = await self.extract_links(
url,
)
if not url:
logging(log, _("提取小红书作品链接失败"), WARNING)
self.logging(_("提取小红书作品链接失败"), WARNING)
return
if index:
await self.__deal_extract(
url[0],
download,
index,
log,
bar,
data,
)
else:
@@ -328,8 +335,6 @@ class XHS:
u,
download,
index,
log,
bar,
data,
count=statistics,
)
@@ -337,17 +342,18 @@ class XHS:
]
self.show_statistics(
statistics,
log,
)
async def extract_links(self, url: str, log) -> list:
async def extract_links(
self,
url: str,
) -> list:
urls = []
for i in url.split():
if u := self.SHORT.search(i):
i = await self.html.request_url(
u.group(),
False,
log,
)
if u := self.SHARE.search(i):
urls.append(u.group())
@@ -369,7 +375,6 @@ class XHS:
async def _get_html_data(
self,
url: str,
log,
data: bool,
cookie: str = None,
proxy: str = None,
@@ -382,19 +387,18 @@ class XHS:
) -> tuple[str, Namespace | dict]:
if await self.skip_download(id_ := self.__extract_link_id(url)) and not data:
msg = _("作品 {0} 存在下载记录,跳过处理").format(id_)
logging(log, msg)
self.logging(msg)
count.skip += 1
return id_, {"message": msg}
logging(log, _("开始处理作品:{0}").format(id_))
self.logging(_("开始处理作品:{0}").format(id_))
html = await self.html.request_url(
url,
log=log,
cookie=cookie,
proxy=proxy,
)
namespace = self.__generate_data_object(html)
if not namespace:
logging(log, _("{0} 获取数据失败").format(id_), ERROR)
self.logging(_("{0} 获取数据失败").format(id_), ERROR)
count.fail += 1
return id_, {}
return id_, namespace
@@ -403,13 +407,11 @@ class XHS:
self,
namespace: Namespace,
id_: str,
log,
count,
):
data = self.explore.run(namespace)
# logging(log, data) # 调试代码
if not data:
logging(log, _("{0} 提取数据失败").format(id_), ERROR)
self.logging(_("{0} 提取数据失败").format(id_), ERROR)
count.fail += 1
return {}
return data
@@ -421,8 +423,6 @@ class XHS:
id_: str,
download: bool,
index: list | tuple | None,
log,
bar,
count: SimpleNamespace,
):
if data["作品类型"] == _("视频"):
@@ -433,16 +433,16 @@ class XHS:
}:
self.__extract_image(data, namespace)
else:
logging(log, _("未知的作品类型:{0}").format(id_), WARNING)
self.logging(_("未知的作品类型:{0}").format(id_), WARNING)
data["下载地址"] = []
data["动图地址"] = []
await self.update_author_nickname(data, log)
await self.update_author_nickname(
data,
)
await self.__download_files(
data,
download,
index,
log,
bar,
count,
)
# await sleep_time()
@@ -453,8 +453,6 @@ class XHS:
url: str,
download: bool,
index: list | tuple | None,
log,
bar,
data: bool,
cookie: str = None,
proxy: str = None,
@@ -467,7 +465,6 @@ class XHS:
):
id_, namespace = await self._get_html_data(
url,
log,
data,
cookie,
proxy,
@@ -479,7 +476,6 @@ class XHS:
data := self._extract_data(
namespace,
id_,
log,
count,
)
):
@@ -490,19 +486,15 @@ class XHS:
id_,
download,
index,
log,
bar,
count,
)
logging(log, _("作品处理完成:{0}").format(id_))
self.logging(_("作品处理完成:{0}").format(id_))
return data
async def deal_script_tasks(
self,
data: dict,
index: list | tuple | None,
log=None,
bar=None,
count=SimpleNamespace(
all=0,
success=0,
@@ -516,7 +508,6 @@ class XHS:
data := self._extract_data(
namespace,
id_,
log,
count,
)
):
@@ -527,8 +518,6 @@ class XHS:
id_,
True,
index,
log,
bar,
count,
)
@@ -539,7 +528,6 @@ class XHS:
async def update_author_nickname(
self,
container: dict,
log,
):
if a := self.CLEANER.filter_name(
self.mapping_data.get(i := container["作者ID"], "")
@@ -550,7 +538,6 @@ class XHS:
await self.mapping.update_cache(
i,
container["作者昵称"],
log,
)
@staticmethod
@@ -602,13 +589,10 @@ class XHS:
async def monitor(
self,
delay=1,
download=False,
log=None,
bar=None,
data=True,
download=True,
data=False,
) -> None:
logging(
None,
self.logging(
_(
"程序会自动读取并提取剪贴板中的小红书作品链接,并自动下载链接对应的作品文件,如需关闭,请点击关闭按钮,或者向剪贴板写入 “close” 文本!"
),
@@ -618,7 +602,7 @@ class XHS:
copy("")
await gather(
self.__get_link(delay),
self.__receive_link(delay, download, None, log, bar, data),
self.__receive_link(delay, download=download, index=None, data=data),
)
async def __get_link(self, delay: int):
@@ -988,3 +972,10 @@ class XHS:
await self.switch_script_server(
switch=self.manager.script_server,
)
def logging(self, text, style=INFO):
    # Instance-level log shim: forwards ``text`` to the module-level
    # ``logging`` helper, using this instance's Print wrapper
    # (``self.print``) as the output target so callers no longer need
    # to pass a ``log`` handle around.
    # ``style`` defaults to INFO; callers pass WARNING/ERROR variants.
    # NOTE: the method name shadows the module-level ``logging`` only
    # as an attribute — inside the body the global helper is resolved.
    logging(
        self.print,
        text,
        style,
    )

View File

@@ -44,6 +44,7 @@ class Download:
manager: "Manager",
):
self.manager = manager
self.print = manager.print
self.folder = manager.folder
self.temp = manager.temp
self.chunk = manager.chunk
@@ -76,8 +77,6 @@ class Download:
filename: str,
type_: str,
mtime: int,
log,
bar,
) -> tuple[Path, list[Any]]:
path = self.__generate_path(nickname, filename)
if type_ == _("视频"):
@@ -85,7 +84,6 @@ class Download:
urls,
path,
filename,
log,
)
elif type_ in {
_("图文"),
@@ -97,7 +95,6 @@ class Download:
index,
path,
filename,
log,
)
else:
raise ValueError
@@ -108,8 +105,6 @@ class Download:
name,
format_,
mtime,
log,
bar,
)
for url, name, format_ in tasks
]
@@ -127,12 +122,18 @@ class Download:
return path
def __ready_download_video(
self, urls: list[str], path: Path, name: str, log
self,
urls: list[str],
path: Path,
name: str,
) -> list:
if not self.video_download:
logging(log, _("视频作品下载功能已关闭,跳过下载"))
logging(self.print, _("视频作品下载功能已关闭,跳过下载"))
return []
if self.__check_exists_path(path, f"{name}.{self.video_format}", log):
if self.__check_exists_path(
path,
f"{name}.{self.video_format}",
):
return []
return [(urls[0], name, self.video_format)]
@@ -143,11 +144,10 @@ class Download:
index: list | tuple | None,
path: Path,
name: str,
log,
) -> list:
tasks = []
if not self.image_download:
logging(log, _("图文作品下载功能已关闭,跳过下载"))
logging(self.print, _("图文作品下载功能已关闭,跳过下载"))
return tasks
for i, j in enumerate(zip(urls, lives), start=1):
if index and i not in index:
@@ -157,7 +157,6 @@ class Download:
self.__check_exists_path(
path,
f"{file}.{s}",
log,
)
for s in self.image_format_list
):
@@ -168,32 +167,29 @@ class Download:
or self.__check_exists_path(
path,
f"{file}.{self.live_format}",
log,
)
):
continue
tasks.append([j[1], file, self.live_format])
return tasks
@staticmethod
def __check_exists_glob(
self,
path: Path,
name: str,
log,
) -> bool:
if any(path.glob(name)):
logging(log, _("{0} 文件已存在,跳过下载").format(name))
logging(self.print, _("{0} 文件已存在,跳过下载").format(name))
return True
return False
@staticmethod
def __check_exists_path(
self,
path: Path,
name: str,
log,
) -> bool:
if path.joinpath(name).exists():
logging(log, _("{0} 文件已存在,跳过下载").format(name))
logging(self.print, _("{0} 文件已存在,跳过下载").format(name))
return True
return False
@@ -205,26 +201,9 @@ class Download:
name: str,
format_: str,
mtime: int,
log,
bar,
):
async with self.SEMAPHORE:
headers = self.headers.copy()
# try:
# length, suffix = await self.__head_file(
# url,
# headers,
# format_,
# )
# except HTTPError as error:
# logging(
# log,
# _(
# "网络异常,{0} 请求失败,错误信息: {1}").format(name, repr(error)),
# ERROR,
# )
# return False
# temp = self.temp.joinpath(f"{name}.{suffix}")
temp = self.temp.joinpath(f"{name}.{format_}")
self.__update_headers_range(
headers,
@@ -258,7 +237,6 @@ class Download:
name,
# suffix,
format_,
log,
)
self.manager.move(
temp,
@@ -267,12 +245,12 @@ class Download:
self.write_mtime,
)
# self.__create_progress(bar, None)
logging(log, _("文件 {0} 下载成功").format(real.name))
logging(self.print, _("文件 {0} 下载成功").format(real.name))
return True
except HTTPError as error:
# self.__create_progress(bar, None)
logging(
log,
self.print,
_("网络异常,{0} 下载失败,错误信息: {1}").format(
name, repr(error)
),
@@ -282,7 +260,7 @@ class Download:
except CacheError as error:
self.manager.delete(temp)
logging(
log,
self.print,
str(error),
ERROR,
)
@@ -335,13 +313,12 @@ class Download:
headers["Range"] = f"bytes={(p := self.__get_resume_byte_position(file))}-"
return p
@staticmethod
async def __suffix_with_file(
self,
temp: Path,
path: Path,
name: str,
default_suffix: str,
log,
) -> Path:
try:
async with open(temp, "rb") as f:
@@ -351,7 +328,7 @@ class Download:
return path.joinpath(f"{name}.{suffix}")
except Exception as error:
logging(
log,
self.print,
_("文件 {0} 格式判断失败,错误信息:{1}").format(
temp.name, repr(error)
),

View File

@@ -17,6 +17,7 @@ class Html:
self,
manager: "Manager",
):
self.print = manager.print
self.retry = manager.retry
self.client = manager.request_client
self.headers = manager.headers
@@ -27,7 +28,6 @@ class Html:
self,
url: str,
content=True,
log=None,
cookie: str = None,
proxy: str = None,
**kwargs,
@@ -62,7 +62,9 @@ class Html:
raise ValueError
except HTTPError as error:
logging(
log, _("网络异常,{0} 请求失败: {1}").format(url, repr(error)), ERROR
self.print,
_("网络异常,{0} 请求失败: {1}").format(url, repr(error)),
ERROR,
)
return ""