diff --git a/README.md b/README.md
index 0d980a8..12927e2 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,8 @@
✅ 支持 API 调用功能
✅ 支持文件断点续传下载
✅ 智能识别作品文件类型
+✅ 支持设置作者备注
+✅ 自动更新作者昵称
Script Features
- ✅ Download RedNote watermark-free works files
@@ -567,8 +569,8 @@ repository to execute the build process
✨ Other Open Source Projects by the Author:
💰 Sponsor
diff --git a/source/TUI/about.py b/source/TUI/about.py
index 85424e7..cabf77e 100644
--- a/source/TUI/about.py
+++ b/source/TUI/about.py
@@ -2,16 +2,13 @@ from rich.text import Text
from textual.app import ComposeResult
from textual.binding import Binding
from textual.screen import Screen
-from textual.widgets import Footer
-from textual.widgets import Header
-from textual.widgets import Label
-from textual.widgets import Link
+from textual.widgets import Footer, Header, Label, Link
from ..module import (
+ INFO,
+ MASTER,
PROJECT,
PROMPT,
- MASTER,
- INFO,
)
from ..translation import _
@@ -26,7 +23,7 @@ class About(Screen):
]
def __init__(
- self,
+ self,
):
super().__init__()
diff --git a/source/TUI/index.py b/source/TUI/index.py
index df8a5cb..ae78468 100644
--- a/source/TUI/index.py
+++ b/source/TUI/index.py
@@ -1,33 +1,25 @@
from pyperclip import paste
from rich.text import Text
-from textual import on
-from textual import work
+from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
-from textual.containers import HorizontalScroll
-from textual.containers import ScrollableContainer
+from textual.containers import HorizontalScroll, ScrollableContainer
from textual.screen import Screen
-from textual.widgets import Button
-from textual.widgets import Footer
-from textual.widgets import Header
-from textual.widgets import Input
-from textual.widgets import Label
-from textual.widgets import Link
-from textual.widgets import RichLog
+from textual.widgets import Button, Footer, Header, Input, Label, Link, RichLog
-from .monitor import Monitor
from ..application import XHS
from ..module import (
+ ERROR,
+ GENERAL,
+ LICENCE,
+ MASTER,
PROJECT,
PROMPT,
- MASTER,
- ERROR,
- WARNING,
- LICENCE,
REPOSITORY,
- GENERAL,
+ WARNING,
)
from ..translation import _
+from .monitor import Monitor
__all__ = ["Index"]
@@ -43,8 +35,8 @@ class Index(Screen):
]
def __init__(
- self,
- app: XHS,
+ self,
+ app: XHS,
):
super().__init__()
self.xhs = app
@@ -119,12 +111,12 @@ class Index(Screen):
async def deal(self):
await self.app.push_screen("loading")
if any(
- await self.xhs.extract(
- self.url.value,
- True,
- log=self.tip,
- data=False,
- )
+ await self.xhs.extract(
+ self.url.value,
+ True,
+ log=self.tip,
+ data=False,
+ )
):
self.url.value = ""
else:
diff --git a/source/TUI/loading.py b/source/TUI/loading.py
index 585e912..7c954af 100644
--- a/source/TUI/loading.py
+++ b/source/TUI/loading.py
@@ -1,8 +1,7 @@
from textual.app import ComposeResult
from textual.containers import Grid
from textual.screen import ModalScreen
-from textual.widgets import Label
-from textual.widgets import LoadingIndicator
+from textual.widgets import Label, LoadingIndicator
from ..translation import _
@@ -11,7 +10,7 @@ __all__ = ["Loading"]
class Loading(ModalScreen):
def __init__(
- self,
+ self,
):
super().__init__()
diff --git a/source/TUI/monitor.py b/source/TUI/monitor.py
index 7460a85..bfc9662 100644
--- a/source/TUI/monitor.py
+++ b/source/TUI/monitor.py
@@ -1,20 +1,15 @@
from rich.text import Text
-from textual import on
-from textual import work
+from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.screen import Screen
-from textual.widgets import Button
-from textual.widgets import Footer
-from textual.widgets import Header
-from textual.widgets import Label
-from textual.widgets import RichLog
+from textual.widgets import Button, Footer, Header, Label, RichLog
from ..application import XHS
from ..module import (
- PROJECT,
- MASTER,
INFO,
+ MASTER,
+ PROJECT,
)
from ..translation import _
@@ -28,8 +23,8 @@ class Monitor(Screen):
]
def __init__(
- self,
- app: XHS,
+ self,
+ app: XHS,
):
super().__init__()
self.xhs = app
diff --git a/source/TUI/record.py b/source/TUI/record.py
index e607e7b..d568b9e 100644
--- a/source/TUI/record.py
+++ b/source/TUI/record.py
@@ -1,11 +1,8 @@
from textual import on
from textual.app import ComposeResult
-from textual.containers import Grid
-from textual.containers import HorizontalScroll
+from textual.containers import Grid, HorizontalScroll
from textual.screen import ModalScreen
-from textual.widgets import Button
-from textual.widgets import Input
-from textual.widgets import Label
+from textual.widgets import Button, Input, Label
from ..application import XHS
from ..translation import _
@@ -15,8 +12,8 @@ __all__ = ["Record"]
class Record(ModalScreen):
def __init__(
- self,
- app: XHS,
+ self,
+ app: XHS,
):
super().__init__()
self.xhs = app
diff --git a/source/TUI/setting.py b/source/TUI/setting.py
index 8bdad3b..28d3a25 100644
--- a/source/TUI/setting.py
+++ b/source/TUI/setting.py
@@ -1,16 +1,9 @@
from textual import on
from textual.app import ComposeResult
from textual.binding import Binding
-from textual.containers import Container
-from textual.containers import ScrollableContainer
+from textual.containers import Container, ScrollableContainer
from textual.screen import Screen
-from textual.widgets import Button
-from textual.widgets import Checkbox
-from textual.widgets import Footer
-from textual.widgets import Header
-from textual.widgets import Input
-from textual.widgets import Label
-from textual.widgets import Select
+from textual.widgets import Button, Checkbox, Footer, Header, Input, Label, Select
from ..translation import _
@@ -217,6 +210,7 @@ class Setting(Screen):
def save_settings(self):
self.dismiss(
{
+ "mapping_data": self.data.get("mapping_data", {}),
"work_path": self.query_one("#work_path").value,
"folder_name": self.query_one("#folder_name").value,
"name_format": self.query_one("#name_format").value,
diff --git a/source/TUI/update.py b/source/TUI/update.py
index dab1fa9..61bfd7c 100644
--- a/source/TUI/update.py
+++ b/source/TUI/update.py
@@ -2,8 +2,7 @@ from textual import work
from textual.app import ComposeResult
from textual.containers import Grid
from textual.screen import ModalScreen
-from textual.widgets import Label
-from textual.widgets import LoadingIndicator
+from textual.widgets import Label, LoadingIndicator
from ..application import XHS
from ..module import (
@@ -16,8 +15,8 @@ __all__ = ["Update"]
class Update(ModalScreen):
def __init__(
- self,
- app: XHS,
+ self,
+ app: XHS,
):
super().__init__()
self.xhs = app
@@ -79,7 +78,7 @@ class Update(ModalScreen):
@staticmethod
def compare_versions(
- current_version: str, target_version: str, is_development: bool
+ current_version: str, target_version: str, is_development: bool
) -> int:
current_major, current_minor = map(int, current_version.split("."))
target_major, target_minor = map(int, target_version.split("."))
diff --git a/source/application/app.py b/source/application/app.py
index 0e91615..e15b187 100644
--- a/source/application/app.py
+++ b/source/application/app.py
@@ -1,9 +1,4 @@
-from asyncio import Event
-from asyncio import Queue
-from asyncio import QueueEmpty
-from asyncio import create_task
-from asyncio import gather
-from asyncio import sleep
+from asyncio import Event, Queue, QueueEmpty, create_task, gather, sleep
from contextlib import suppress
from datetime import datetime
from re import compile
@@ -11,37 +6,40 @@ from urllib.parse import urlparse
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
-from pyperclip import copy
# from aiohttp import web
-from pyperclip import paste
-from uvicorn import Config
-from uvicorn import Server
+from pyperclip import copy, paste
+from uvicorn import Config, Server
-from source.expansion import BrowserCookie
-from source.expansion import Cleaner
-from source.expansion import Converter
-from source.expansion import Namespace
-from source.expansion import beautify_string
-from source.module import DataRecorder
-from source.module import ExtractData
-from source.module import ExtractParams
-from source.module import IDRecorder
-from source.module import Manager
+from source.expansion import (
+ BrowserCookie,
+ Cleaner,
+ Converter,
+ Namespace,
+ beautify_string,
+)
from source.module import (
- ROOT,
+ __VERSION__,
ERROR,
- WARNING,
MASTER,
REPOSITORY,
+ ROOT,
+ VERSION_BETA,
VERSION_MAJOR,
VERSION_MINOR,
- VERSION_BETA,
- __VERSION__,
+ WARNING,
+ DataRecorder,
+ ExtractData,
+ ExtractParams,
+ IDRecorder,
+ Manager,
+ MapRecorder,
+ logging,
+ sleep_time,
)
-from source.module import logging
-from source.module import sleep_time
-from source.translation import switch_language, _
+from source.translation import _, switch_language
+
+from ..module import Mapping
from .download import Download
from .explore import Explore
from .image import Image
@@ -87,6 +85,7 @@ class XHS:
def __init__(
self,
+ mapping_data: dict = None,
work_path="",
folder_name="Download",
name_format="发布时间 作者昵称 作品标题",
@@ -132,6 +131,11 @@ class XHS:
author_archive,
_print,
)
+ self.mapping_data = mapping_data or {}
+ self.map_recorder = MapRecorder(
+ self.manager,
+ )
+ self.mapping = Mapping(self.manager, self.map_recorder)
self.html = Html(self.manager)
self.image = Image()
self.video = Video()
@@ -309,11 +313,29 @@ class XHS:
self.__extract_image(data, namespace)
else:
data["下载地址"] = []
+ await self.update_author_nickname(data, log)
await self.__download_files(data, download, index, log, bar)
logging(log, _("作品处理完成:{0}").format(i))
await sleep_time()
return data
+ async def update_author_nickname(
+ self,
+ container: dict,
+ log,
+ ):
+ if a := self.CLEANER.filter_name(
+ self.mapping_data.get(i := container["作者ID"], "")
+ ):
+ container["作者昵称"] = a
+ else:
+ container["作者昵称"] = self.manager.filter_name(container["作者昵称"]) or i
+ await self.mapping.update_cache(
+ i,
+ container["作者昵称"],
+ log,
+ )
+
@staticmethod
def __extract_link_id(url: str) -> str:
link = urlparse(url)
@@ -330,8 +352,6 @@ class XHS:
match key:
case "发布时间":
values.append(self.__get_name_time(data))
- case "作者昵称":
- values.append(self.__get_name_author(data))
case "作品标题":
values.append(self.__get_name_title(data))
case _:
@@ -353,9 +373,6 @@ class XHS:
def __get_name_time(data: dict) -> str:
return data["发布时间"].replace(":", ".")
- def __get_name_author(self, data: dict) -> str:
- return self.manager.filter_name(data["作者昵称"]) or data["作者ID"]
-
def __get_name_title(self, data: dict) -> str:
return (
beautify_string(
@@ -419,11 +436,13 @@ class XHS:
async def __aenter__(self):
await self.id_recorder.__aenter__()
await self.data_recorder.__aenter__()
+ await self.map_recorder.__aenter__()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.id_recorder.__aexit__(exc_type, exc_value, traceback)
await self.data_recorder.__aexit__(exc_type, exc_value, traceback)
+ await self.map_recorder.__aexit__(exc_type, exc_value, traceback)
await self.close()
async def close(self):
diff --git a/source/application/download.py b/source/application/download.py
index 8a59124..d2d20c5 100644
--- a/source/application/download.py
+++ b/source/application/download.py
@@ -1,5 +1,4 @@
-from asyncio import Semaphore
-from asyncio import gather
+from asyncio import Semaphore, gather
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -7,22 +6,24 @@ from aiofiles import open
from httpx import HTTPError
from ..expansion import CacheError
-from ..module import ERROR
-from ..module import (
- FILE_SIGNATURES_LENGTH,
- FILE_SIGNATURES,
-)
-from ..module import MAX_WORKERS
+
# from ..module import WARNING
-from ..module import Manager
-from ..module import logging
+from ..module import (
+ ERROR,
+ FILE_SIGNATURES,
+ FILE_SIGNATURES_LENGTH,
+ MAX_WORKERS,
+ logging,
+ sleep_time,
+)
from ..module import retry as re_download
-from ..module import sleep_time
from ..translation import _
if TYPE_CHECKING:
from httpx import AsyncClient
+ from ..module import Manager
+
__all__ = ["Download"]
@@ -39,8 +40,8 @@ class Download:
}
def __init__(
- self,
- manager: Manager,
+ self,
+ manager: "Manager",
):
self.manager = manager
self.folder = manager.folder
@@ -66,15 +67,15 @@ class Download:
self.author_archive = manager.author_archive
async def run(
- self,
- urls: list,
- lives: list,
- index: list | tuple | None,
- nickname: str,
- filename: str,
- type_: str,
- log,
- bar,
+ self,
+ urls: list,
+ lives: list,
+ index: list | tuple | None,
+ nickname: str,
+ filename: str,
+ type_: str,
+ log,
+ bar,
) -> tuple[Path, list[Any]]:
path = self.__generate_path(nickname, filename)
if type_ == _("视频"):
@@ -109,7 +110,7 @@ class Download:
tasks = await gather(*tasks)
return path, tasks
- def __generate_path(self, nickname:str, filename: str):
+ def __generate_path(self, nickname: str, filename: str):
if self.author_archive:
folder = self.folder.joinpath(nickname)
folder.mkdir(exist_ok=True)
@@ -120,7 +121,7 @@ class Download:
return path
def __ready_download_video(
- self, urls: list[str], path: Path, name: str, log
+ self, urls: list[str], path: Path, name: str, log
) -> list:
if not self.video_download:
logging(log, _("视频作品下载功能已关闭,跳过下载"))
@@ -130,13 +131,13 @@ class Download:
return [(urls[0], name, self.video_format)]
def __ready_download_image(
- self,
- urls: list[str],
- lives: list[str],
- index: list | tuple | None,
- path: Path,
- name: str,
- log,
+ self,
+ urls: list[str],
+ lives: list[str],
+ index: list | tuple | None,
+ path: Path,
+ name: str,
+ log,
) -> list:
tasks = []
if not self.image_download:
@@ -147,32 +148,32 @@ class Download:
continue
file = f"{name}_{i}"
if not any(
- self.__check_exists_path(
- path,
- f"{file}.{s}",
- log,
- )
- for s in self.image_format_list
+ self.__check_exists_path(
+ path,
+ f"{file}.{s}",
+ log,
+ )
+ for s in self.image_format_list
):
tasks.append([j[0], file, self.image_format])
if (
- not self.live_download
- or not j[1]
- or self.__check_exists_path(
+ not self.live_download
+ or not j[1]
+ or self.__check_exists_path(
path,
f"{file}.{self.live_format}",
log,
- )
+ )
):
continue
tasks.append([j[1], file, self.live_format])
return tasks
def __check_exists_glob(
- self,
- path: Path,
- name: str,
- log,
+ self,
+ path: Path,
+ name: str,
+ log,
) -> bool:
if any(path.glob(name)):
logging(log, _("{0} 文件已存在,跳过下载").format(name))
@@ -180,10 +181,10 @@ class Download:
return False
def __check_exists_path(
- self,
- path: Path,
- name: str,
- log,
+ self,
+ path: Path,
+ name: str,
+ log,
) -> bool:
if path.joinpath(name).exists():
logging(log, _("{0} 文件已存在,跳过下载").format(name))
@@ -192,13 +193,13 @@ class Download:
@re_download
async def __download(
- self,
- url: str,
- path: Path,
- name: str,
- format_: str,
- log,
- bar,
+ self,
+ url: str,
+ path: Path,
+ name: str,
+ format_: str,
+ log,
+ bar,
):
async with self.SEMAPHORE:
headers = self.headers.copy()
@@ -224,9 +225,9 @@ class Download:
)
try:
async with self.client.stream(
- "GET",
- url,
- headers=headers,
+ "GET",
+ url,
+ headers=headers,
) as response:
await sleep_time()
if response.status_code == 416:
@@ -276,9 +277,9 @@ class Download:
@staticmethod
def __create_progress(
- bar,
- total: int | None,
- completed=0,
+ bar,
+ total: int | None,
+ completed=0,
):
if bar:
bar.update(total=total, completed=completed)
@@ -293,10 +294,10 @@ class Download:
return cls.CONTENT_TYPE_MAP.get(content, "")
async def __head_file(
- self,
- url: str,
- headers: dict[str, str],
- suffix: str,
+ self,
+ url: str,
+ headers: dict[str, str],
+ suffix: str,
) -> tuple[int, str]:
response = await self.client.head(
url,
@@ -313,26 +314,26 @@ class Download:
return file.stat().st_size if file.is_file() else 0
def __update_headers_range(
- self,
- headers: dict[str, str],
- file: Path,
+ self,
+ headers: dict[str, str],
+ file: Path,
) -> int:
headers["Range"] = f"bytes={(p := self.__get_resume_byte_position(file))}-"
return p
async def __suffix_with_file(
- self,
- temp: Path,
- path: Path,
- name: str,
- default_suffix: str,
- log,
+ self,
+ temp: Path,
+ path: Path,
+ name: str,
+ default_suffix: str,
+ log,
) -> Path:
try:
async with open(temp, "rb") as f:
file_start = await f.read(FILE_SIGNATURES_LENGTH)
for offset, signature, suffix in FILE_SIGNATURES:
- if file_start[offset: offset + len(signature)] == signature:
+ if file_start[offset : offset + len(signature)] == signature:
return path.joinpath(f"{name}.{suffix}")
except Exception as error:
logging(
diff --git a/source/application/image.py b/source/application/image.py
index c6f8cf5..2e2b0a3 100644
--- a/source/application/image.py
+++ b/source/application/image.py
@@ -1,4 +1,5 @@
from source.expansion import Namespace
+
from .request import Html
__all__ = ["Image"]
@@ -37,8 +38,8 @@ class Image:
@staticmethod
def __generate_fixed_link(
- token: str,
- format_: str,
+ token: str,
+ format_: str,
) -> str:
return f"https://ci.xiaohongshu.com/{token}?imageView2/format/{format_}"
@@ -50,10 +51,10 @@ class Image:
def __get_live_link(items: list) -> list:
return [
(
- Html.format_url(
- Namespace.object_extract(item, "stream.h264[0].masterUrl")
- )
- or None
+ Html.format_url(
+ Namespace.object_extract(item, "stream.h264[0].masterUrl")
+ )
+ or None
)
for item in items
]
diff --git a/source/application/request.py b/source/application/request.py
index b5c5765..d937553 100644
--- a/source/application/request.py
+++ b/source/application/request.py
@@ -1,19 +1,20 @@
+from typing import TYPE_CHECKING
+
from httpx import HTTPError
-from ..module import ERROR
-from ..module import Manager
-from ..module import logging
-from ..module import retry
-from ..module import sleep_time
+from ..module import ERROR, Manager, logging, retry, sleep_time
from ..translation import _
+if TYPE_CHECKING:
+ from ..module import Manager
+
__all__ = ["Html"]
class Html:
def __init__(
- self,
- manager: Manager,
+ self,
+ manager: "Manager",
):
self.retry = manager.retry
self.client = manager.request_client
@@ -21,12 +22,12 @@ class Html:
@retry
async def request_url(
- self,
- url: str,
- content=True,
- log=None,
- cookie: str = None,
- **kwargs,
+ self,
+ url: str,
+ content=True,
+ log=None,
+ cookie: str = None,
+ **kwargs,
) -> str:
headers = self.update_cookie(
cookie,
@@ -63,16 +64,16 @@ class Html:
return bytes(url, "utf-8").decode("unicode_escape")
def update_cookie(
- self,
- cookie: str = None,
+ self,
+ cookie: str = None,
) -> dict:
return self.headers | {"Cookie": cookie} if cookie else self.headers.copy()
async def __request_url_head(
- self,
- url: str,
- headers: dict,
- **kwargs,
+ self,
+ url: str,
+ headers: dict,
+ **kwargs,
):
return await self.client.head(
url,
@@ -81,10 +82,10 @@ class Html:
)
async def __request_url_get(
- self,
- url: str,
- headers: dict,
- **kwargs,
+ self,
+ url: str,
+ headers: dict,
+ **kwargs,
):
return await self.client.get(
url,
diff --git a/source/expansion/browser.py b/source/expansion/browser.py
index 33e08e3..f80b273 100644
--- a/source/expansion/browser.py
+++ b/source/expansion/browser.py
@@ -39,9 +39,9 @@ class BrowserCookie:
@classmethod
def run(
- cls,
- domains: list[str],
- console: Console = None,
+ cls,
+ domains: list[str],
+ console: Console = None,
) -> str:
console = console or Console()
options = "\n".join(
@@ -49,11 +49,11 @@ class BrowserCookie:
for i, (k, v) in enumerate(cls.SUPPORT_BROWSER.items(), start=1)
)
if browser := console.input(
- _(
- "读取指定浏览器的 Cookie 并写入配置文件\n"
- "Windows 系统需要以管理员身份运行程序才能读取 Chromium、Chrome、Edge 浏览器 Cookie!\n"
- "{options}\n请输入浏览器名称或序号:"
- ).format(options=options),
+ _(
+ "读取指定浏览器的 Cookie 并写入配置文件\n"
+ "Windows 系统需要以管理员身份运行程序才能读取 Chromium、Chrome、Edge 浏览器 Cookie!\n"
+ "{options}\n请输入浏览器名称或序号:"
+ ).format(options=options),
):
return cls.get(
browser,
@@ -64,10 +64,10 @@ class BrowserCookie:
@classmethod
def get(
- cls,
- browser: str | int,
- domains: list[str],
- console: Console = None,
+ cls,
+ browser: str | int,
+ domains: list[str],
+ console: Console = None,
) -> str:
console = console or Console()
if not (browser := cls.__browser_object(browser)):
diff --git a/source/expansion/cleaner.py b/source/expansion/cleaner.py
index 154cc09..6f3c7b5 100644
--- a/source/expansion/cleaner.py
+++ b/source/expansion/cleaner.py
@@ -68,10 +68,10 @@ class Cleaner:
return text
def filter_name(
- self,
- text: str,
- replace: str = "",
- default: str = "",
+ self,
+ text: str,
+ replace: str = "",
+ default: str = "",
) -> str:
"""过滤文件夹名称中的非法字符"""
text = text.replace(":", ".")
@@ -98,9 +98,9 @@ class Cleaner:
@classmethod
def remove_control_characters(
- cls,
- text,
- replace="",
+ cls,
+ text,
+ replace="",
):
# 使用正则表达式匹配所有控制字符
return cls.CONTROL_CHARACTERS.sub(
diff --git a/source/expansion/file_folder.py b/source/expansion/file_folder.py
index acb3adc..a2d183b 100644
--- a/source/expansion/file_folder.py
+++ b/source/expansion/file_folder.py
@@ -16,7 +16,7 @@ def remove_empty_directories(path: Path) -> None:
"\\__",
}
for dir_path, dir_names, file_names in path.walk(
- top_down=False,
+ top_down=False,
):
if any(i in str(dir_path) for i in exclude):
continue
diff --git a/source/expansion/namespace.py b/source/expansion/namespace.py
index 08017c6..c490008 100644
--- a/source/expansion/namespace.py
+++ b/source/expansion/namespace.py
@@ -24,17 +24,17 @@ class Namespace:
return depth_conversion(data)
def safe_extract(
- self,
- attribute_chain: str,
- default: Union[str, int, list, dict, SimpleNamespace] = "",
+ self,
+ attribute_chain: str,
+ default: Union[str, int, list, dict, SimpleNamespace] = "",
):
return self.__safe_extract(self.data, attribute_chain, default)
@staticmethod
def __safe_extract(
- data_object: SimpleNamespace,
- attribute_chain: str,
- default: Union[str, int, list, dict, SimpleNamespace] = "",
+ data_object: SimpleNamespace,
+ attribute_chain: str,
+ default: Union[str, int, list, dict, SimpleNamespace] = "",
):
data = deepcopy(data_object)
attributes = attribute_chain.split(".")
@@ -56,10 +56,10 @@ class Namespace:
@classmethod
def object_extract(
- cls,
- data_object: SimpleNamespace,
- attribute_chain: str,
- default: Union[str, int, list, dict, SimpleNamespace] = "",
+ cls,
+ data_object: SimpleNamespace,
+ attribute_chain: str,
+ default: Union[str, int, list, dict, SimpleNamespace] = "",
):
return cls.__safe_extract(
data_object,
diff --git a/source/module/__init__.py b/source/module/__init__.py
index 7bbd59f..f215249 100644
--- a/source/module/__init__.py
+++ b/source/module/__init__.py
@@ -6,6 +6,8 @@ from .model import (
)
from .recorder import DataRecorder
from .recorder import IDRecorder
+from .recorder import MapRecorder
+from .mapping import Mapping
from .settings import Settings
from .static import (
VERSION_MAJOR,
@@ -35,4 +37,5 @@ from .tools import (
retry,
logging,
sleep_time,
+ retry_limited,
)
diff --git a/source/module/manager.py b/source/module/manager.py
index c25b43e..0bbe2ae 100644
--- a/source/module/manager.py
+++ b/source/module/manager.py
@@ -1,22 +1,21 @@
from pathlib import Path
-from re import compile
-from re import sub
-from shutil import move
-from shutil import rmtree
+from re import compile, sub
+from shutil import move, rmtree
-from httpx import AsyncClient
-from httpx import AsyncHTTPTransport
-from httpx import HTTPStatusError
-from httpx import RequestError
-from httpx import TimeoutException
-from httpx import get
+from httpx import (
+ AsyncClient,
+ AsyncHTTPTransport,
+ HTTPStatusError,
+ RequestError,
+ TimeoutException,
+ get,
+)
from source.expansion import remove_empty_directories
-from .static import HEADERS
-from .static import USERAGENT
-from .static import WARNING
-from .tools import logging
+
from ..translation import _
+from .static import HEADERS, USERAGENT, WARNING
+from .tools import logging
__all__ = ["Manager"]
@@ -47,26 +46,26 @@ class Manager:
WEB_SESSION = r"(?:^|; )web_session=[^;]+"
def __init__(
- self,
- root: Path,
- path: str,
- folder: str,
- name_format: str,
- chunk: int,
- user_agent: str,
- cookie: str,
- proxy: str | dict,
- timeout: int,
- retry: int,
- record_data: bool,
- image_format: str,
- image_download: bool,
- video_download: bool,
- live_download: bool,
- download_record: bool,
- folder_mode: bool,
- author_archive:bool,
- _print: bool,
+ self,
+ root: Path,
+ path: str,
+ folder: str,
+ name_format: str,
+ chunk: int,
+ user_agent: str,
+ cookie: str,
+ proxy: str | dict,
+ timeout: int,
+ retry: int,
+ record_data: bool,
+ image_format: str,
+ image_download: bool,
+ video_download: bool,
+ live_download: bool,
+ download_record: bool,
+ folder_mode: bool,
+ author_archive: bool,
+ _print: bool,
):
self.root = root
self.temp = root.joinpath("./temp")
@@ -92,8 +91,8 @@ class Manager:
)
self.request_client = AsyncClient(
headers=self.headers
- | {
- "referer": "https://www.xiaohongshu.com/",
+ | {
+ "referer": "https://www.xiaohongshu.com/",
},
timeout=timeout,
verify=False,
@@ -194,9 +193,9 @@ class Manager:
)
def __check_proxy(
- self,
- proxy: str,
- url="https://www.xiaohongshu.com/explore",
+ self,
+ proxy: str,
+ url="https://www.xiaohongshu.com/explore",
) -> str | None:
if proxy:
try:
@@ -217,8 +216,8 @@ class Manager:
WARNING,
)
except (
- RequestError,
- HTTPStatusError,
+ RequestError,
+ HTTPStatusError,
) as e:
self.proxy_tip = (
_("代理 {0} 测试失败:{1}").format(
@@ -229,9 +228,9 @@ class Manager:
)
def print_proxy_tip(
- self,
- _print: bool = True,
- log=None,
+ self,
+ _print: bool = True,
+ log=None,
) -> None:
if _print and self.proxy_tip:
logging(log, *self.proxy_tip)
diff --git a/source/module/mapping.py b/source/module/mapping.py
new file mode 100644
index 0000000..0441c4c
--- /dev/null
+++ b/source/module/mapping.py
@@ -0,0 +1,236 @@
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from ..translation import _
+from .static import ERROR
+from .tools import logging
+
+if TYPE_CHECKING:
+ from manager import Manager
+ from recorder import MapRecorder
+
+
+__all__ = ["Mapping"]
+
+
+class Mapping:
+ def __init__(
+ self,
+ manager: "Manager",
+ mapping: "MapRecorder",
+ ):
+ self.root = manager.folder
+ self.folder_mode = manager.folder_mode
+ self.database = mapping
+ self.switch = manager.author_archive
+
+ async def update_cache(
+ self,
+ id_: str,
+ alias: str,
+ log=None,
+ ):
+ if not self.switch:
+ return
+ if (a := await self.has_mapping(id_)) and a != alias:
+ self.__check_file(
+ id_,
+ alias,
+ a,
+ log,
+ )
+ await self.database.add(
+ id_,
+ alias,
+ )
+
+ async def has_mapping(self, id_: str) -> str:
+ return d[0] if (d := await self.database.select(id_)) else ""
+
+ def __check_file(
+ self,
+ id_: str,
+ alias: str,
+ old_alias: str,
+ log,
+ ):
+ if not (old_folder := self.root.joinpath(f"{id_}_{old_alias}")).is_dir():
+ logging(
+ log,
+ _("{old_folder} 文件夹不存在,跳过处理").format(
+ old_folder=old_folder.name
+ ),
+ )
+ return
+ self.__rename_folder(
+ old_folder,
+ id_,
+ alias,
+ log,
+ )
+ self.__scan_file(
+ id_,
+ alias,
+ old_alias,
+ log,
+ )
+
+ def __rename_folder(
+ self,
+ old_folder: Path,
+ id_: str,
+ alias: str,
+ log,
+ ):
+ new_folder = self.root.joinpath(f"{id_}_{alias}")
+ self.__rename(
+ old_folder,
+ new_folder,
+ _("文件夹"),
+ log,
+ )
+ logging(
+ log,
+ _("文件夹 {old_folder} 已重命名为 {new_folder}").format(
+ old_folder=old_folder.name, new_folder=new_folder.name
+ ),
+ )
+
+ def __rename_works_folder(
+ self,
+ old_: Path,
+ alias: str,
+ old_alias: str,
+ log,
+ ) -> Path:
+ if old_alias in old_.name:
+ new_ = old_.parent / old_.name.replace(old_alias, alias, 1)
+ self.__rename(
+ old_,
+ new_,
+ _("文件夹"),
+ log,
+ )
+ logging(
+ log,
+ _("文件夹 {old_} 重命名为 {new_}").format(
+ old_=old_.name, new_=new_.name
+ ),
+ )
+ return new_
+ return old_
+
+ def __scan_file(
+ self,
+ id_: str,
+ alias: str,
+ old_alias: str,
+ log,
+ ):
+ root = self.root.joinpath(f"{id_}_{alias}")
+ item_list = root.iterdir()
+ if self.folder_mode:
+ for f in item_list:
+ if f.is_dir():
+ f = self.__rename_works_folder(
+ f,
+ alias,
+ old_alias,
+ log,
+ )
+ files = f.iterdir()
+ self.__batch_rename(
+ f,
+ files,
+ alias,
+ old_alias,
+ log,
+ )
+ else:
+ self.__batch_rename(
+ root,
+ item_list,
+ alias,
+ old_alias,
+ log,
+ )
+
+ def __batch_rename(
+ self,
+ root: Path,
+ files,
+ alias: str,
+ old_alias: str,
+ log,
+ ):
+ for old_file in files:
+ if old_alias not in old_file.name:
+ break
+ self.__rename_file(
+ root,
+ old_file,
+ alias,
+ old_alias,
+ log,
+ )
+
+ def __rename_file(
+ self,
+ root: Path,
+ old_file: Path,
+ alias: str,
+ old_alias: str,
+ log,
+ ):
+ new_file = root.joinpath(old_file.name.replace(old_alias, alias, 1))
+ self.__rename(
+ old_file,
+ new_file,
+ _("文件"),
+ log,
+ )
+ logging(
+ log,
+ _("文件 {old_file} 重命名为 {new_file}").format(
+ old_file=old_file.name, new_file=new_file.name
+ ),
+ )
+ return True
+
+ @staticmethod
+ def __rename(
+ old_: Path,
+ new_: Path,
+ type_=_("文件"),
+ log=None,
+ ) -> bool:
+ try:
+ old_.rename(new_)
+ return True
+ except PermissionError as e:
+ logging(
+ log,
+ _("{type} {old}被占用,重命名失败: {error}").format(
+ type=type_, old=old_.name, error=e
+ ),
+ ERROR,
+ )
+ return False
+ except FileExistsError as e:
+ logging(
+ log,
+ _("{type} {new}名称重复,重命名失败: {error}").format(
+ type=type_, new=new_.name, error=e
+ ),
+ ERROR,
+ )
+ return False
+ except OSError as e:
+ logging(
+ log,
+ _("处理{type} {old}时发生预期之外的错误: {error}").format(
+ type=type_, old=old_.name, error=e
+ ),
+ ERROR,
+ )
+ return True
diff --git a/source/module/recorder.py b/source/module/recorder.py
index 3096e21..fc524f0 100644
--- a/source/module/recorder.py
+++ b/source/module/recorder.py
@@ -1,18 +1,17 @@
from asyncio import CancelledError
from contextlib import suppress
+from typing import TYPE_CHECKING
from aiosqlite import connect
-from ..module import Manager
+if TYPE_CHECKING:
+ from ..module import Manager
-__all__ = [
- "IDRecorder",
- "DataRecorder",
-]
+__all__ = ["IDRecorder", "DataRecorder", "MapRecorder"]
class IDRecorder:
- def __init__(self, manager: Manager):
+ def __init__(self, manager: "Manager"):
self.file = manager.root.joinpath("ExploreID.db")
self.switch = manager.download_record
self.database = None
@@ -31,7 +30,12 @@ class IDRecorder:
await self.cursor.execute("SELECT ID FROM explore_id WHERE ID=?", (id_,))
return await self.cursor.fetchone()
- async def add(self, id_: str) -> None:
+ async def add(
+ self,
+ id_: str,
+ *args,
+ **kwargs,
+ ) -> None:
if self.switch:
await self.database.execute("REPLACE INTO explore_id VALUES (?);", (id_,))
await self.database.commit()
@@ -82,7 +86,7 @@ class DataRecorder(IDRecorder):
("动图地址", "TEXT"),
)
- def __init__(self, manager: Manager):
+ def __init__(self, manager: "Manager"):
super().__init__(manager)
self.file = manager.folder.joinpath("ExploreData.db")
self.switch = manager.record_data
@@ -121,3 +125,54 @@ class DataRecorder(IDRecorder):
def __generate_values(self, data: dict) -> tuple:
return tuple(data[i] for i, _ in self.DATA_TABLE)
+
+
+class MapRecorder(IDRecorder):
+ def __init__(self, manager: "Manager"):
+ super().__init__(manager)
+ self.file = manager.root.joinpath("MappingData.db")
+ self.switch = manager.author_archive
+
+ async def _connect_database(self):
+ self.database = await connect(self.file)
+ self.cursor = await self.database.cursor()
+ await self.database.execute(
+ "CREATE TABLE IF NOT EXISTS mapping_data ("
+ "ID TEXT PRIMARY KEY,"
+ "NAME TEXT NOT NULL"
+ ");"
+ )
+ await self.database.commit()
+
+ async def select(self, id_: str):
+ if self.switch:
+ await self.cursor.execute(
+ "SELECT NAME FROM mapping_data WHERE ID=?", (id_,)
+ )
+ return await self.cursor.fetchone()
+
+ async def add(
+ self,
+ id_: str,
+ name: str,
+ ) -> None:
+ if self.switch:
+ await self.database.execute(
+ "REPLACE INTO mapping_data VALUES (?, ?);",
+ (
+ id_,
+ name,
+ ),
+ )
+ await self.database.commit()
+
+ async def __delete(self, id_: str) -> None:
+ pass
+
+ async def delete(self, ids: list[str]):
+ pass
+
+ async def all(self):
+ if self.switch:
+ await self.cursor.execute("SELECT ID, NAME FROM mapping_data")
+ return [i[0] for i in await self.cursor.fetchmany()]
diff --git a/source/module/settings.py b/source/module/settings.py
index 6aef1e2..0fbe735 100644
--- a/source/module/settings.py
+++ b/source/module/settings.py
@@ -1,16 +1,15 @@
-from json import dump
-from json import load
+from json import dump, load
from pathlib import Path
from platform import system
-from .static import ROOT
-from .static import USERAGENT
+from .static import ROOT, USERAGENT
__all__ = ["Settings"]
class Settings:
default = {
+ "mapping_data": {},
"work_path": "",
"folder_name": "Download",
"name_format": "发布时间 作者昵称 作品标题",
@@ -53,11 +52,11 @@ class Settings:
@classmethod
def check_keys(
- cls,
- data: dict,
- callback: callable,
- *args,
- **kwargs,
+ cls,
+ data: dict,
+ callback: callable,
+ *args,
+ **kwargs,
) -> dict:
needful_keys = set(cls.default.keys())
given_keys = set(data.keys())
diff --git a/source/module/static.py b/source/module/static.py
index 51bd92d..5b36bb8 100644
--- a/source/module/static.py
+++ b/source/module/static.py
@@ -6,7 +6,7 @@ VERSION_BETA = True
__VERSION__ = f"{VERSION_MAJOR}.{VERSION_MINOR}.{'beta' if VERSION_BETA else 'stable'}"
ROOT = Path(__file__).resolve().parent.parent.parent
PROJECT = f"XHS-Downloader V{VERSION_MAJOR}.{VERSION_MINOR} {
-'Beta' if VERSION_BETA else 'Stable'
+ 'Beta' if VERSION_BETA else 'Stable'
}"
REPOSITORY = "https://github.com/JoeanAmier/XHS-Downloader"
@@ -22,7 +22,7 @@ USERAGENT = (
HEADERS = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,"
- "application/signed-exchange;v=b3;q=0.7",
+ "application/signed-exchange;v=b3;q=0.7",
"referer": "https://www.xiaohongshu.com/explore",
"user-agent": USERAGENT,
}
diff --git a/source/module/tools.py b/source/module/tools.py
index 8a830fe..c2a25d0 100644
--- a/source/module/tools.py
+++ b/source/module/tools.py
@@ -4,6 +4,7 @@ from random import uniform
from rich import print
from rich.text import Text
+from ..translation import _
from .static import INFO
@@ -11,7 +12,7 @@ def retry(function):
async def inner(self, *args, **kwargs):
if result := await function(self, *args, **kwargs):
return result
- for _ in range(self.retry):
+ for __ in range(self.retry):
if result := await function(self, *args, **kwargs):
return result
return result
@@ -19,6 +20,23 @@ def retry(function):
return inner
+def retry_limited(function):
+ # TODO: 不支持 TUI
+ def inner(self, *args, **kwargs):
+ while True:
+ if function(self, *args, **kwargs):
+ return
+ if self.console.input(
+ _(
+ "如需重新尝试处理该对象,请关闭所有正在访问该对象的窗口或程序,然后直接按下回车键!\n"
+ "如需跳过处理该对象,请输入任意字符后按下回车键!"
+ ),
+ ):
+ return
+
+ return inner
+
+
def logging(log, text, style=INFO):
string = Text(text, style=style)
if log:
@@ -31,7 +49,7 @@ def logging(log, text, style=INFO):
async def sleep_time(
- min_time: int | float = 0.5,
- max_time: int | float = 1.5,
+ min_time: int | float = 1.0,
+ max_time: int | float = 2.5,
):
await sleep(uniform(min_time, max_time))
diff --git a/static/Release_Notes.md b/static/Release_Notes.md
index c71efd4..1b9fcce 100644
--- a/static/Release_Notes.md
+++ b/static/Release_Notes.md
@@ -4,12 +4,14 @@
2. 新增启动 `监听剪贴板` 模式时清空剪贴板内容
3. 修复 `监听剪贴板` 模式可能丢失链接的问题
4. 支持按作者归档保存作品文件
-5. 优化 `headers` 处理逻辑
-6. 支持 `SOCKS` 代理
+5. 新增自动更新作者昵称功能
+6. 优化 `headers` 处理逻辑
+7. 支持 `SOCKS` 代理
+8. 支持设置作者别名
**注意:**
-配置文件新增参数 author_archive,旧版本更新需要手动添加配置内容:"author_archive": false;或者直接删除旧版配置文件后再运行程序!
+配置文件新增参数 mapping_data、author_archive,旧版本更新需要手动添加配置内容:"mapping_data": {}, "author_archive": false;或者直接删除旧版配置文件后再运行程序!
*****