refactor: 重构运行提示传递逻辑

This commit is contained in:
Quan 2025-12-16 12:55:52 +08:00
parent 8039c6a5db
commit 855113c94d
11 changed files with 117 additions and 181 deletions

View File

@ -77,13 +77,12 @@ class Index(Screen):
self.title = PROJECT self.title = PROJECT
self.url = self.query_one(Input) self.url = self.query_one(Input)
self.tip = self.query_one(RichLog) self.tip = self.query_one(RichLog)
self.xhs.print.func = self.tip
self.tip.write( self.tip.write(
Text(_("免责声明\n") + f"\n{'>' * 50}", style=MASTER), Text(_("免责声明\n") + f"\n{'>' * 50}", style=MASTER),
scroll_end=True, scroll_end=True,
) )
self.xhs.manager.print_proxy_tip( self.xhs.manager.print_proxy_tip()
log=self.tip,
)
@on(Button.Pressed, "#deal") @on(Button.Pressed, "#deal")
async def deal_button(self): async def deal_button(self):
@ -114,7 +113,6 @@ class Index(Screen):
await self.xhs.extract( await self.xhs.extract(
self.url.value, self.url.value,
True, True,
log=self.tip,
data=False, data=False,
) )
): ):

View File

@ -8,7 +8,6 @@ from textual.widgets import Button, Footer, Header, Label, RichLog
from ..application import XHS from ..application import XHS
from ..module import ( from ..module import (
INFO, INFO,
MASTER,
PROJECT, PROJECT,
) )
from ..translation import _ from ..translation import _
@ -42,24 +41,12 @@ class Monitor(Screen):
@work(exclusive=True) @work(exclusive=True)
async def run_monitor(self): async def run_monitor(self):
await self.xhs.monitor( await self.xhs.monitor()
download=True,
log=self.query_one(RichLog),
data=False,
)
await self.action_close() await self.action_close()
def on_mount(self) -> None: def on_mount(self) -> None:
self.title = PROJECT self.title = PROJECT
self.query_one(RichLog).write( self.xhs.print.func = self.query_one(RichLog)
Text(
_(
"程序会自动读取并提取剪贴板中的小红书作品链接,并自动下载链接对应的作品文件,如需关闭,请点击关闭按钮,或者向剪贴板写入 “close” 文本!"
),
style=MASTER,
),
scroll_end=True,
)
self.run_monitor() self.run_monitor()
async def action_close(self): async def action_close(self):

View File

@ -40,7 +40,6 @@ class Record(ModalScreen):
async def delete(self, text: str): async def delete(self, text: str):
text = await self.xhs.extract_links( text = await self.xhs.extract_links(
text, text,
None,
) )
text = self.xhs.extract_id(text) text = self.xhs.extract_id(text)
await self.xhs.id_recorder.delete(text) await self.xhs.id_recorder.delete(text)

View File

@ -34,7 +34,6 @@ class Update(ModalScreen):
url = await self.xhs.html.request_url( url = await self.xhs.html.request_url(
RELEASES, RELEASES,
False, False,
None,
timeout=5, timeout=5,
) )
version = url.split("/")[-1] version = url.split("/")[-1]

View File

@ -21,6 +21,7 @@ from pydantic import Field
from types import SimpleNamespace from types import SimpleNamespace
from pyperclip import copy, paste from pyperclip import copy, paste
from uvicorn import Config, Server from uvicorn import Config, Server
from typing import Callable
from ..expansion import ( from ..expansion import (
BrowserCookie, BrowserCookie,
@ -48,6 +49,7 @@ from ..module import (
logging, logging,
# sleep_time, # sleep_time,
ScriptServer, ScriptServer,
INFO,
) )
from ..translation import _, switch_language from ..translation import _, switch_language
@ -57,6 +59,7 @@ from .explore import Explore
from .image import Image from .image import Image
from .request import Html from .request import Html
from .video import Video from .video import Video
from rich import print
__all__ = ["XHS"] __all__ = ["XHS"]
@ -79,6 +82,19 @@ def data_cache(function):
return inner return inner
class Print:
def __init__(
self,
func: Callable = print,
):
self.func = func
def __call__(
self,
):
return self.func
class XHS: class XHS:
VERSION_MAJOR = VERSION_MAJOR VERSION_MAJOR = VERSION_MAJOR
VERSION_MINOR = VERSION_MINOR VERSION_MINOR = VERSION_MINOR
@ -123,11 +139,11 @@ class XHS:
script_server: bool = False, script_server: bool = False,
script_host="0.0.0.0", script_host="0.0.0.0",
script_port=5558, script_port=5558,
_print: bool = True,
*args, *args,
**kwargs, **kwargs,
): ):
switch_language(language) switch_language(language)
self.print = Print()
self.manager = Manager( self.manager = Manager(
ROOT, ROOT,
work_path, work_path,
@ -149,8 +165,8 @@ class XHS:
author_archive, author_archive,
write_mtime, write_mtime,
script_server, script_server,
_print,
self.CLEANER, self.CLEANER,
self.print,
) )
self.mapping_data = mapping_data or {} self.mapping_data = mapping_data or {}
self.map_recorder = MapRecorder( self.map_recorder = MapRecorder(
@ -190,14 +206,12 @@ class XHS:
container: dict, container: dict,
download: bool, download: bool,
index, index,
log,
bar,
count: SimpleNamespace, count: SimpleNamespace,
): ):
name = self.__naming_rules(container) name = self.__naming_rules(container)
if (u := container["下载地址"]) and download: if (u := container["下载地址"]) and download:
if await self.skip_download(i := container["作品ID"]): if await self.skip_download(i := container["作品ID"]):
logging(log, _("作品 {0} 存在下载记录,跳过下载").format(i)) self.logging(_("作品 {0} 存在下载记录,跳过下载").format(i))
count.skip += 1 count.skip += 1
else: else:
__, result = await self.download.run( __, result = await self.download.run(
@ -210,8 +224,6 @@ class XHS:
name, name,
container["作品类型"], container["作品类型"],
container["时间戳"], container["时间戳"],
log,
bar,
) )
if result: if result:
count.success += 1 count.success += 1
@ -221,7 +233,7 @@ class XHS:
else: else:
count.fail += 1 count.fail += 1
elif not u: elif not u:
logging(log, _("提取作品文件下载地址失败"), ERROR) self.logging(_("提取作品文件下载地址失败"), ERROR)
count.fail += 1 count.fail += 1
await self.save_data(container) await self.save_data(container)
@ -247,12 +259,14 @@ class XHS:
url: str, url: str,
download=False, download=False,
index: list | tuple = None, index: list | tuple = None,
log=None,
bar=None,
data=True, data=True,
) -> list[dict]: ) -> list[dict]:
if not (urls := await self.extract_links(url, log)): if not (
logging(log, _("提取小红书作品链接失败"), WARNING) urls := await self.extract_links(
url,
)
):
self.logging(_("提取小红书作品链接失败"), WARNING)
return [] return []
statistics = SimpleNamespace( statistics = SimpleNamespace(
all=len(urls), all=len(urls),
@ -260,14 +274,12 @@ class XHS:
fail=0, fail=0,
skip=0, skip=0,
) )
logging(log, _("{0} 个小红书作品待处理...").format(statistics.all)) self.logging(_("{0} 个小红书作品待处理...").format(statistics.all))
result = [ result = [
await self.__deal_extract( await self.__deal_extract(
i, i,
download, download,
index, index,
log,
bar,
data, data,
count=statistics, count=statistics,
) )
@ -275,17 +287,14 @@ class XHS:
] ]
self.show_statistics( self.show_statistics(
statistics, statistics,
log,
) )
return result return result
@staticmethod
def show_statistics( def show_statistics(
self,
statistics: SimpleNamespace, statistics: SimpleNamespace,
log=None,
) -> None: ) -> None:
logging( self.logging(
log,
_("共处理 {0} 个作品,成功 {1} 个,失败 {2} 个,跳过 {3}").format( _("共处理 {0} 个作品,成功 {1} 个,失败 {2} 个,跳过 {3}").format(
statistics.all, statistics.all,
statistics.success, statistics.success,
@ -299,21 +308,19 @@ class XHS:
url: str, url: str,
download=True, download=True,
index: list | tuple = None, index: list | tuple = None,
log=None,
bar=None,
data=False, data=False,
) -> None: ) -> None:
url = await self.extract_links(url, log) url = await self.extract_links(
url,
)
if not url: if not url:
logging(log, _("提取小红书作品链接失败"), WARNING) self.logging(_("提取小红书作品链接失败"), WARNING)
return return
if index: if index:
await self.__deal_extract( await self.__deal_extract(
url[0], url[0],
download, download,
index, index,
log,
bar,
data, data,
) )
else: else:
@ -328,8 +335,6 @@ class XHS:
u, u,
download, download,
index, index,
log,
bar,
data, data,
count=statistics, count=statistics,
) )
@ -337,17 +342,18 @@ class XHS:
] ]
self.show_statistics( self.show_statistics(
statistics, statistics,
log,
) )
async def extract_links(self, url: str, log) -> list: async def extract_links(
self,
url: str,
) -> list:
urls = [] urls = []
for i in url.split(): for i in url.split():
if u := self.SHORT.search(i): if u := self.SHORT.search(i):
i = await self.html.request_url( i = await self.html.request_url(
u.group(), u.group(),
False, False,
log,
) )
if u := self.SHARE.search(i): if u := self.SHARE.search(i):
urls.append(u.group()) urls.append(u.group())
@ -369,7 +375,6 @@ class XHS:
async def _get_html_data( async def _get_html_data(
self, self,
url: str, url: str,
log,
data: bool, data: bool,
cookie: str = None, cookie: str = None,
proxy: str = None, proxy: str = None,
@ -382,19 +387,18 @@ class XHS:
) -> tuple[str, Namespace | dict]: ) -> tuple[str, Namespace | dict]:
if await self.skip_download(id_ := self.__extract_link_id(url)) and not data: if await self.skip_download(id_ := self.__extract_link_id(url)) and not data:
msg = _("作品 {0} 存在下载记录,跳过处理").format(id_) msg = _("作品 {0} 存在下载记录,跳过处理").format(id_)
logging(log, msg) self.logging(msg)
count.skip += 1 count.skip += 1
return id_, {"message": msg} return id_, {"message": msg}
logging(log, _("开始处理作品:{0}").format(id_)) self.logging(_("开始处理作品:{0}").format(id_))
html = await self.html.request_url( html = await self.html.request_url(
url, url,
log=log,
cookie=cookie, cookie=cookie,
proxy=proxy, proxy=proxy,
) )
namespace = self.__generate_data_object(html) namespace = self.__generate_data_object(html)
if not namespace: if not namespace:
logging(log, _("{0} 获取数据失败").format(id_), ERROR) self.logging(_("{0} 获取数据失败").format(id_), ERROR)
count.fail += 1 count.fail += 1
return id_, {} return id_, {}
return id_, namespace return id_, namespace
@ -403,13 +407,11 @@ class XHS:
self, self,
namespace: Namespace, namespace: Namespace,
id_: str, id_: str,
log,
count, count,
): ):
data = self.explore.run(namespace) data = self.explore.run(namespace)
# logging(log, data) # 调试代码
if not data: if not data:
logging(log, _("{0} 提取数据失败").format(id_), ERROR) self.logging(_("{0} 提取数据失败").format(id_), ERROR)
count.fail += 1 count.fail += 1
return {} return {}
return data return data
@ -421,8 +423,6 @@ class XHS:
id_: str, id_: str,
download: bool, download: bool,
index: list | tuple | None, index: list | tuple | None,
log,
bar,
count: SimpleNamespace, count: SimpleNamespace,
): ):
if data["作品类型"] == _("视频"): if data["作品类型"] == _("视频"):
@ -433,16 +433,16 @@ class XHS:
}: }:
self.__extract_image(data, namespace) self.__extract_image(data, namespace)
else: else:
logging(log, _("未知的作品类型:{0}").format(id_), WARNING) self.logging(_("未知的作品类型:{0}").format(id_), WARNING)
data["下载地址"] = [] data["下载地址"] = []
data["动图地址"] = [] data["动图地址"] = []
await self.update_author_nickname(data, log) await self.update_author_nickname(
data,
)
await self.__download_files( await self.__download_files(
data, data,
download, download,
index, index,
log,
bar,
count, count,
) )
# await sleep_time() # await sleep_time()
@ -453,8 +453,6 @@ class XHS:
url: str, url: str,
download: bool, download: bool,
index: list | tuple | None, index: list | tuple | None,
log,
bar,
data: bool, data: bool,
cookie: str = None, cookie: str = None,
proxy: str = None, proxy: str = None,
@ -467,7 +465,6 @@ class XHS:
): ):
id_, namespace = await self._get_html_data( id_, namespace = await self._get_html_data(
url, url,
log,
data, data,
cookie, cookie,
proxy, proxy,
@ -479,7 +476,6 @@ class XHS:
data := self._extract_data( data := self._extract_data(
namespace, namespace,
id_, id_,
log,
count, count,
) )
): ):
@ -490,19 +486,15 @@ class XHS:
id_, id_,
download, download,
index, index,
log,
bar,
count, count,
) )
logging(log, _("作品处理完成:{0}").format(id_)) self.logging(_("作品处理完成:{0}").format(id_))
return data return data
async def deal_script_tasks( async def deal_script_tasks(
self, self,
data: dict, data: dict,
index: list | tuple | None, index: list | tuple | None,
log=None,
bar=None,
count=SimpleNamespace( count=SimpleNamespace(
all=0, all=0,
success=0, success=0,
@ -516,7 +508,6 @@ class XHS:
data := self._extract_data( data := self._extract_data(
namespace, namespace,
id_, id_,
log,
count, count,
) )
): ):
@ -527,8 +518,6 @@ class XHS:
id_, id_,
True, True,
index, index,
log,
bar,
count, count,
) )
@ -539,7 +528,6 @@ class XHS:
async def update_author_nickname( async def update_author_nickname(
self, self,
container: dict, container: dict,
log,
): ):
if a := self.CLEANER.filter_name( if a := self.CLEANER.filter_name(
self.mapping_data.get(i := container["作者ID"], "") self.mapping_data.get(i := container["作者ID"], "")
@ -550,7 +538,6 @@ class XHS:
await self.mapping.update_cache( await self.mapping.update_cache(
i, i,
container["作者昵称"], container["作者昵称"],
log,
) )
@staticmethod @staticmethod
@ -602,13 +589,10 @@ class XHS:
async def monitor( async def monitor(
self, self,
delay=1, delay=1,
download=False, download=True,
log=None, data=False,
bar=None,
data=True,
) -> None: ) -> None:
logging( self.logging(
None,
_( _(
"程序会自动读取并提取剪贴板中的小红书作品链接,并自动下载链接对应的作品文件,如需关闭,请点击关闭按钮,或者向剪贴板写入 “close” 文本!" "程序会自动读取并提取剪贴板中的小红书作品链接,并自动下载链接对应的作品文件,如需关闭,请点击关闭按钮,或者向剪贴板写入 “close” 文本!"
), ),
@ -618,7 +602,7 @@ class XHS:
copy("") copy("")
await gather( await gather(
self.__get_link(delay), self.__get_link(delay),
self.__receive_link(delay, download, None, log, bar, data), self.__receive_link(delay, download=download, index=None, data=data),
) )
async def __get_link(self, delay: int): async def __get_link(self, delay: int):
@ -988,3 +972,10 @@ class XHS:
await self.switch_script_server( await self.switch_script_server(
switch=self.manager.script_server, switch=self.manager.script_server,
) )
def logging(self, text, style=INFO):
logging(
self.print,
text,
style,
)

View File

@ -44,6 +44,7 @@ class Download:
manager: "Manager", manager: "Manager",
): ):
self.manager = manager self.manager = manager
self.print = manager.print
self.folder = manager.folder self.folder = manager.folder
self.temp = manager.temp self.temp = manager.temp
self.chunk = manager.chunk self.chunk = manager.chunk
@ -76,8 +77,6 @@ class Download:
filename: str, filename: str,
type_: str, type_: str,
mtime: int, mtime: int,
log,
bar,
) -> tuple[Path, list[Any]]: ) -> tuple[Path, list[Any]]:
path = self.__generate_path(nickname, filename) path = self.__generate_path(nickname, filename)
if type_ == _("视频"): if type_ == _("视频"):
@ -85,7 +84,6 @@ class Download:
urls, urls,
path, path,
filename, filename,
log,
) )
elif type_ in { elif type_ in {
_("图文"), _("图文"),
@ -97,7 +95,6 @@ class Download:
index, index,
path, path,
filename, filename,
log,
) )
else: else:
raise ValueError raise ValueError
@ -108,8 +105,6 @@ class Download:
name, name,
format_, format_,
mtime, mtime,
log,
bar,
) )
for url, name, format_ in tasks for url, name, format_ in tasks
] ]
@ -127,12 +122,18 @@ class Download:
return path return path
def __ready_download_video( def __ready_download_video(
self, urls: list[str], path: Path, name: str, log self,
urls: list[str],
path: Path,
name: str,
) -> list: ) -> list:
if not self.video_download: if not self.video_download:
logging(log, _("视频作品下载功能已关闭,跳过下载")) logging(self.print, _("视频作品下载功能已关闭,跳过下载"))
return [] return []
if self.__check_exists_path(path, f"{name}.{self.video_format}", log): if self.__check_exists_path(
path,
f"{name}.{self.video_format}",
):
return [] return []
return [(urls[0], name, self.video_format)] return [(urls[0], name, self.video_format)]
@ -143,11 +144,10 @@ class Download:
index: list | tuple | None, index: list | tuple | None,
path: Path, path: Path,
name: str, name: str,
log,
) -> list: ) -> list:
tasks = [] tasks = []
if not self.image_download: if not self.image_download:
logging(log, _("图文作品下载功能已关闭,跳过下载")) logging(self.print, _("图文作品下载功能已关闭,跳过下载"))
return tasks return tasks
for i, j in enumerate(zip(urls, lives), start=1): for i, j in enumerate(zip(urls, lives), start=1):
if index and i not in index: if index and i not in index:
@ -157,7 +157,6 @@ class Download:
self.__check_exists_path( self.__check_exists_path(
path, path,
f"{file}.{s}", f"{file}.{s}",
log,
) )
for s in self.image_format_list for s in self.image_format_list
): ):
@ -168,32 +167,29 @@ class Download:
or self.__check_exists_path( or self.__check_exists_path(
path, path,
f"{file}.{self.live_format}", f"{file}.{self.live_format}",
log,
) )
): ):
continue continue
tasks.append([j[1], file, self.live_format]) tasks.append([j[1], file, self.live_format])
return tasks return tasks
@staticmethod
def __check_exists_glob( def __check_exists_glob(
self,
path: Path, path: Path,
name: str, name: str,
log,
) -> bool: ) -> bool:
if any(path.glob(name)): if any(path.glob(name)):
logging(log, _("{0} 文件已存在,跳过下载").format(name)) logging(self.print, _("{0} 文件已存在,跳过下载").format(name))
return True return True
return False return False
@staticmethod
def __check_exists_path( def __check_exists_path(
self,
path: Path, path: Path,
name: str, name: str,
log,
) -> bool: ) -> bool:
if path.joinpath(name).exists(): if path.joinpath(name).exists():
logging(log, _("{0} 文件已存在,跳过下载").format(name)) logging(self.print, _("{0} 文件已存在,跳过下载").format(name))
return True return True
return False return False
@ -205,26 +201,9 @@ class Download:
name: str, name: str,
format_: str, format_: str,
mtime: int, mtime: int,
log,
bar,
): ):
async with self.SEMAPHORE: async with self.SEMAPHORE:
headers = self.headers.copy() headers = self.headers.copy()
# try:
# length, suffix = await self.__head_file(
# url,
# headers,
# format_,
# )
# except HTTPError as error:
# logging(
# log,
# _(
# "网络异常,{0} 请求失败,错误信息: {1}").format(name, repr(error)),
# ERROR,
# )
# return False
# temp = self.temp.joinpath(f"{name}.{suffix}")
temp = self.temp.joinpath(f"{name}.{format_}") temp = self.temp.joinpath(f"{name}.{format_}")
self.__update_headers_range( self.__update_headers_range(
headers, headers,
@ -258,7 +237,6 @@ class Download:
name, name,
# suffix, # suffix,
format_, format_,
log,
) )
self.manager.move( self.manager.move(
temp, temp,
@ -267,12 +245,12 @@ class Download:
self.write_mtime, self.write_mtime,
) )
# self.__create_progress(bar, None) # self.__create_progress(bar, None)
logging(log, _("文件 {0} 下载成功").format(real.name)) logging(self.print, _("文件 {0} 下载成功").format(real.name))
return True return True
except HTTPError as error: except HTTPError as error:
# self.__create_progress(bar, None) # self.__create_progress(bar, None)
logging( logging(
log, self.print,
_("网络异常,{0} 下载失败,错误信息: {1}").format( _("网络异常,{0} 下载失败,错误信息: {1}").format(
name, repr(error) name, repr(error)
), ),
@ -282,7 +260,7 @@ class Download:
except CacheError as error: except CacheError as error:
self.manager.delete(temp) self.manager.delete(temp)
logging( logging(
log, self.print,
str(error), str(error),
ERROR, ERROR,
) )
@ -335,13 +313,12 @@ class Download:
headers["Range"] = f"bytes={(p := self.__get_resume_byte_position(file))}-" headers["Range"] = f"bytes={(p := self.__get_resume_byte_position(file))}-"
return p return p
@staticmethod
async def __suffix_with_file( async def __suffix_with_file(
self,
temp: Path, temp: Path,
path: Path, path: Path,
name: str, name: str,
default_suffix: str, default_suffix: str,
log,
) -> Path: ) -> Path:
try: try:
async with open(temp, "rb") as f: async with open(temp, "rb") as f:
@ -351,7 +328,7 @@ class Download:
return path.joinpath(f"{name}.{suffix}") return path.joinpath(f"{name}.{suffix}")
except Exception as error: except Exception as error:
logging( logging(
log, self.print,
_("文件 {0} 格式判断失败,错误信息:{1}").format( _("文件 {0} 格式判断失败,错误信息:{1}").format(
temp.name, repr(error) temp.name, repr(error)
), ),

View File

@ -17,6 +17,7 @@ class Html:
self, self,
manager: "Manager", manager: "Manager",
): ):
self.print = manager.print
self.retry = manager.retry self.retry = manager.retry
self.client = manager.request_client self.client = manager.request_client
self.headers = manager.headers self.headers = manager.headers
@ -27,7 +28,6 @@ class Html:
self, self,
url: str, url: str,
content=True, content=True,
log=None,
cookie: str = None, cookie: str = None,
proxy: str = None, proxy: str = None,
**kwargs, **kwargs,
@ -62,7 +62,9 @@ class Html:
raise ValueError raise ValueError
except HTTPError as error: except HTTPError as error:
logging( logging(
log, _("网络异常,{0} 请求失败: {1}").format(url, repr(error)), ERROR self.print,
_("网络异常,{0} 请求失败: {1}").format(url, repr(error)),
ERROR,
) )
return "" return ""

View File

@ -6,7 +6,7 @@ system to be compatible with PyInstaller's frozen import mechanism.
The Problem: The Problem:
----------- -----------
Beartype's `beartype.claw` module installs a custom path hook into Beartype's `beartype.claw` module installs a custom path hook into
`sys.path_hooks` that uses `SourceFileLoader` to load and transform Python `sys.path_hooks` that uses `SourceFileLoader` to load and transform Python
source files. In PyInstaller's frozen environment: source files. In PyInstaller's frozen environment:
@ -36,26 +36,26 @@ import sys
def _is_pyinstaller_frozen(): def _is_pyinstaller_frozen():
"""Check if running in a PyInstaller frozen environment.""" """Check if running in a PyInstaller frozen environment."""
return getattr(sys, 'frozen', False) or hasattr(sys, '_MEIPASS') return getattr(sys, "frozen", False) or hasattr(sys, "_MEIPASS")
def _patch_beartype_claw(): def _patch_beartype_claw():
""" """
Patch beartype's add_beartype_pathhook to skip in frozen environments. Patch beartype's add_beartype_pathhook to skip in frozen environments.
This patches the function at import time before any user code can call it. This patches the function at import time before any user code can call it.
""" """
# Only patch if we're actually frozen # Only patch if we're actually frozen
if not _is_pyinstaller_frozen(): if not _is_pyinstaller_frozen():
return return
try: try:
# Import the module containing the function to patch # Import the module containing the function to patch
from beartype.claw._importlib import clawimpmain from beartype.claw._importlib import clawimpmain
# Store the original function # Store the original function
_original_add_beartype_pathhook = clawimpmain.add_beartype_pathhook _original_add_beartype_pathhook = clawimpmain.add_beartype_pathhook
def _patched_add_beartype_pathhook(): def _patched_add_beartype_pathhook():
""" """
Patched version of add_beartype_pathhook that skips in frozen env. Patched version of add_beartype_pathhook that skips in frozen env.
@ -66,15 +66,16 @@ def _patch_beartype_claw():
return return
# Otherwise, call the original # Otherwise, call the original
return _original_add_beartype_pathhook() return _original_add_beartype_pathhook()
# Replace the function in clawimpmain module # Replace the function in clawimpmain module
clawimpmain.add_beartype_pathhook = _patched_add_beartype_pathhook clawimpmain.add_beartype_pathhook = _patched_add_beartype_pathhook
# CRITICAL: Also patch clawpkgmain which does `from ... import add_beartype_pathhook` # CRITICAL: Also patch clawpkgmain which does `from ... import add_beartype_pathhook`
# and thus has its own local reference to the original function # and thus has its own local reference to the original function
from beartype.claw._package import clawpkgmain from beartype.claw._package import clawpkgmain
clawpkgmain.add_beartype_pathhook = _patched_add_beartype_pathhook clawpkgmain.add_beartype_pathhook = _patched_add_beartype_pathhook
except ImportError: except ImportError:
# beartype not installed or not using claw module # beartype not installed or not using claw module
pass pass
@ -85,4 +86,3 @@ def _patch_beartype_claw():
# Apply the patch when this runtime hook is loaded # Apply the patch when this runtime hook is loaded
_patch_beartype_claw() _patch_beartype_claw()

View File

@ -71,9 +71,10 @@ class Manager:
author_archive: bool, author_archive: bool,
write_mtime: bool, write_mtime: bool,
script_server: bool, script_server: bool,
_print: bool,
cleaner: "Cleaner", cleaner: "Cleaner",
print_object,
): ):
self.print = print_object
self.root = root self.root = root
self.cleaner = cleaner self.cleaner = cleaner
self.temp = root.joinpath("Temp") self.temp = root.joinpath("Temp")
@ -95,9 +96,7 @@ class Manager:
self.download_record = self.check_bool(download_record, True) self.download_record = self.check_bool(download_record, True)
self.proxy_tip = None self.proxy_tip = None
self.proxy = self.__check_proxy(proxy) self.proxy = self.__check_proxy(proxy)
self.print_proxy_tip( self.print_proxy_tip()
_print,
)
self.timeout = timeout self.timeout = timeout
self.request_client = AsyncClient( self.request_client = AsyncClient(
headers=self.headers headers=self.headers
@ -249,14 +248,13 @@ class Manager:
), ),
WARNING, WARNING,
) )
return None
def print_proxy_tip( def print_proxy_tip(
self, self,
_print: bool = True,
log=None,
) -> None: ) -> None:
if _print and self.proxy_tip: if self.proxy_tip:
logging(log, *self.proxy_tip) logging(self.print, *self.proxy_tip)
@classmethod @classmethod
def clean_cookie(cls, cookie_string: str) -> str: def clean_cookie(cls, cookie_string: str) -> str:

View File

@ -23,12 +23,12 @@ class Mapping:
self.folder_mode = manager.folder_mode self.folder_mode = manager.folder_mode
self.database = mapping self.database = mapping
self.switch = manager.author_archive self.switch = manager.author_archive
self.print = manager.print
async def update_cache( async def update_cache(
self, self,
id_: str, id_: str,
alias: str, alias: str,
log=None,
): ):
if not self.switch: if not self.switch:
return return
@ -37,7 +37,6 @@ class Mapping:
id_, id_,
alias, alias,
a, a,
log,
) )
await self.database.add(id_, alias) await self.database.add(id_, alias)
@ -49,11 +48,10 @@ class Mapping:
id_: str, id_: str,
alias: str, alias: str,
old_alias: str, old_alias: str,
log,
): ):
if not (old_folder := self.root.joinpath(f"{id_}_{old_alias}")).is_dir(): if not (old_folder := self.root.joinpath(f"{id_}_{old_alias}")).is_dir():
logging( logging(
log, self.print,
_("{old_folder} 文件夹不存在,跳过处理").format( _("{old_folder} 文件夹不存在,跳过处理").format(
old_folder=old_folder.name old_folder=old_folder.name
), ),
@ -63,13 +61,11 @@ class Mapping:
old_folder, old_folder,
id_, id_,
alias, alias,
log,
) )
self.__scan_file( self.__scan_file(
id_, id_,
alias, alias,
old_alias, old_alias,
log,
) )
def __rename_folder( def __rename_folder(
@ -77,17 +73,15 @@ class Mapping:
old_folder: Path, old_folder: Path,
id_: str, id_: str,
alias: str, alias: str,
log,
): ):
new_folder = self.root.joinpath(f"{id_}_{alias}") new_folder = self.root.joinpath(f"{id_}_{alias}")
self.__rename( self.__rename(
old_folder, old_folder,
new_folder, new_folder,
_("文件夹"), _("文件夹"),
log,
) )
logging( logging(
log, self.print,
_("文件夹 {old_folder} 已重命名为 {new_folder}").format( _("文件夹 {old_folder} 已重命名为 {new_folder}").format(
old_folder=old_folder.name, new_folder=new_folder.name old_folder=old_folder.name, new_folder=new_folder.name
), ),
@ -98,7 +92,6 @@ class Mapping:
old_: Path, old_: Path,
alias: str, alias: str,
old_alias: str, old_alias: str,
log,
) -> Path: ) -> Path:
if old_alias in old_.name: if old_alias in old_.name:
new_ = old_.parent / old_.name.replace(old_alias, alias, 1) new_ = old_.parent / old_.name.replace(old_alias, alias, 1)
@ -106,10 +99,9 @@ class Mapping:
old_, old_,
new_, new_,
_("文件夹"), _("文件夹"),
log,
) )
logging( logging(
log, self.print,
_("文件夹 {old_} 重命名为 {new_}").format( _("文件夹 {old_} 重命名为 {new_}").format(
old_=old_.name, new_=new_.name old_=old_.name, new_=new_.name
), ),
@ -122,7 +114,6 @@ class Mapping:
id_: str, id_: str,
alias: str, alias: str,
old_alias: str, old_alias: str,
log,
): ):
root = self.root.joinpath(f"{id_}_{alias}") root = self.root.joinpath(f"{id_}_{alias}")
item_list = root.iterdir() item_list = root.iterdir()
@ -133,7 +124,6 @@ class Mapping:
f, f,
alias, alias,
old_alias, old_alias,
log,
) )
files = f.iterdir() files = f.iterdir()
self.__batch_rename( self.__batch_rename(
@ -141,7 +131,6 @@ class Mapping:
files, files,
alias, alias,
old_alias, old_alias,
log,
) )
else: else:
self.__batch_rename( self.__batch_rename(
@ -149,7 +138,6 @@ class Mapping:
item_list, item_list,
alias, alias,
old_alias, old_alias,
log,
) )
def __batch_rename( def __batch_rename(
@ -158,7 +146,6 @@ class Mapping:
files, files,
alias: str, alias: str,
old_alias: str, old_alias: str,
log,
): ):
for old_file in files: for old_file in files:
if old_alias not in old_file.name: if old_alias not in old_file.name:
@ -168,7 +155,6 @@ class Mapping:
old_file, old_file,
alias, alias,
old_alias, old_alias,
log,
) )
def __rename_file( def __rename_file(
@ -177,36 +163,33 @@ class Mapping:
old_file: Path, old_file: Path,
alias: str, alias: str,
old_alias: str, old_alias: str,
log,
): ):
new_file = root.joinpath(old_file.name.replace(old_alias, alias, 1)) new_file = root.joinpath(old_file.name.replace(old_alias, alias, 1))
self.__rename( self.__rename(
old_file, old_file,
new_file, new_file,
_("文件"), _("文件"),
log,
) )
logging( logging(
log, self.print,
_("文件 {old_file} 重命名为 {new_file}").format( _("文件 {old_file} 重命名为 {new_file}").format(
old_file=old_file.name, new_file=new_file.name old_file=old_file.name, new_file=new_file.name
), ),
) )
return True return True
@staticmethod
def __rename( def __rename(
self,
old_: Path, old_: Path,
new_: Path, new_: Path,
type_=_("文件"), type_=_("文件"),
log=None,
) -> bool: ) -> bool:
try: try:
old_.rename(new_) old_.rename(new_)
return True return True
except PermissionError as e: except PermissionError as e:
logging( logging(
log, self.print,
_("{type} {old}被占用,重命名失败: {error}").format( _("{type} {old}被占用,重命名失败: {error}").format(
type=type_, old=old_.name, error=e type=type_, old=old_.name, error=e
), ),
@ -215,7 +198,7 @@ class Mapping:
return False return False
except FileExistsError as e: except FileExistsError as e:
logging( logging(
log, self.print,
_("{type} {new}名称重复,重命名失败: {error}").format( _("{type} {new}名称重复,重命名失败: {error}").format(
type=type_, new=new_.name, error=e type=type_, new=new_.name, error=e
), ),
@ -224,7 +207,7 @@ class Mapping:
return False return False
except OSError as e: except OSError as e:
logging( logging(
log, self.print,
_("处理{type} {old}时发生预期之外的错误: {error}").format( _("处理{type} {old}时发生预期之外的错误: {error}").format(
type=type_, old=old_.name, error=e type=type_, old=old_.name, error=e
), ),

View File

@ -1,5 +1,6 @@
from asyncio import sleep from asyncio import sleep
from random import uniform from random import uniform
from typing import Callable
from rich import print from rich import print
from rich.text import Text from rich.text import Text
@ -37,15 +38,16 @@ def retry_limited(function):
return inner return inner
def logging(log, text, style=INFO): def logging(log: Callable, text, style=INFO):
string = Text(text, style=style) string = Text(text, style=style)
if log: func = log()
log.write( if func is print:
func(string)
else:
func.write(
string, string,
scroll_end=True, scroll_end=True,
) )
else:
print(string)
async def sleep_time( async def sleep_time(