feat: 支持设置作者别名

1. 新增 mapping_data 配置参数
2. 新增自动更新作者昵称功能

Closes #176
Closes #194
Closes #199
Closes #229
This commit is contained in:
2025-03-22 22:10:17 +08:00
parent f332b3fb2d
commit 3b4f23c670
25 changed files with 618 additions and 307 deletions

View File

@@ -1,9 +1,4 @@
from asyncio import Event
from asyncio import Queue
from asyncio import QueueEmpty
from asyncio import create_task
from asyncio import gather
from asyncio import sleep
from asyncio import Event, Queue, QueueEmpty, create_task, gather, sleep
from contextlib import suppress
from datetime import datetime
from re import compile
@@ -11,37 +6,40 @@ from urllib.parse import urlparse
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from pyperclip import copy
# from aiohttp import web
from pyperclip import paste
from uvicorn import Config
from uvicorn import Server
from pyperclip import copy, paste
from uvicorn import Config, Server
from source.expansion import BrowserCookie
from source.expansion import Cleaner
from source.expansion import Converter
from source.expansion import Namespace
from source.expansion import beautify_string
from source.module import DataRecorder
from source.module import ExtractData
from source.module import ExtractParams
from source.module import IDRecorder
from source.module import Manager
from source.expansion import (
BrowserCookie,
Cleaner,
Converter,
Namespace,
beautify_string,
)
from source.module import (
ROOT,
__VERSION__,
ERROR,
WARNING,
MASTER,
REPOSITORY,
ROOT,
VERSION_BETA,
VERSION_MAJOR,
VERSION_MINOR,
VERSION_BETA,
__VERSION__,
WARNING,
DataRecorder,
ExtractData,
ExtractParams,
IDRecorder,
Manager,
MapRecorder,
logging,
sleep_time,
)
from source.module import logging
from source.module import sleep_time
from source.translation import switch_language, _
from source.translation import _, switch_language
from ..module import Mapping
from .download import Download
from .explore import Explore
from .image import Image
@@ -87,6 +85,7 @@ class XHS:
def __init__(
self,
mapping_data: dict = None,
work_path="",
folder_name="Download",
name_format="发布时间 作者昵称 作品标题",
@@ -132,6 +131,11 @@ class XHS:
author_archive,
_print,
)
self.mapping_data = mapping_data or {}
self.map_recorder = MapRecorder(
self.manager,
)
self.mapping = Mapping(self.manager, self.map_recorder)
self.html = Html(self.manager)
self.image = Image()
self.video = Video()
@@ -309,11 +313,29 @@ class XHS:
self.__extract_image(data, namespace)
else:
data["下载地址"] = []
await self.update_author_nickname(data, log)
await self.__download_files(data, download, index, log, bar)
logging(log, _("作品处理完成:{0}").format(i))
await sleep_time()
return data
async def update_author_nickname(
self,
container: dict,
log,
):
if a := self.CLEANER.filter_name(
self.mapping_data.get(i := container["作者ID"], "")
):
container["作者昵称"] = a
else:
container["作者昵称"] = self.manager.filter_name(container["作者昵称"]) or i
await self.mapping.update_cache(
i,
container["作者昵称"],
log,
)
@staticmethod
def __extract_link_id(url: str) -> str:
link = urlparse(url)
@@ -330,8 +352,6 @@ class XHS:
match key:
case "发布时间":
values.append(self.__get_name_time(data))
case "作者昵称":
values.append(self.__get_name_author(data))
case "作品标题":
values.append(self.__get_name_title(data))
case _:
@@ -353,9 +373,6 @@ class XHS:
def __get_name_time(data: dict) -> str:
return data["发布时间"].replace(":", ".")
def __get_name_author(self, data: dict) -> str:
return self.manager.filter_name(data["作者昵称"]) or data["作者ID"]
def __get_name_title(self, data: dict) -> str:
return (
beautify_string(
@@ -419,11 +436,13 @@ class XHS:
async def __aenter__(self):
await self.id_recorder.__aenter__()
await self.data_recorder.__aenter__()
await self.map_recorder.__aenter__()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.id_recorder.__aexit__(exc_type, exc_value, traceback)
await self.data_recorder.__aexit__(exc_type, exc_value, traceback)
await self.map_recorder.__aexit__(exc_type, exc_value, traceback)
await self.close()
async def close(self):

View File

@@ -1,5 +1,4 @@
from asyncio import Semaphore
from asyncio import gather
from asyncio import Semaphore, gather
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -7,22 +6,24 @@ from aiofiles import open
from httpx import HTTPError
from ..expansion import CacheError
from ..module import ERROR
from ..module import (
FILE_SIGNATURES_LENGTH,
FILE_SIGNATURES,
)
from ..module import MAX_WORKERS
# from ..module import WARNING
from ..module import Manager
from ..module import logging
from ..module import (
ERROR,
FILE_SIGNATURES,
FILE_SIGNATURES_LENGTH,
MAX_WORKERS,
logging,
sleep_time,
)
from ..module import retry as re_download
from ..module import sleep_time
from ..translation import _
if TYPE_CHECKING:
from httpx import AsyncClient
from ..module import Manager
__all__ = ["Download"]
@@ -39,8 +40,8 @@ class Download:
}
def __init__(
self,
manager: Manager,
self,
manager: "Manager",
):
self.manager = manager
self.folder = manager.folder
@@ -66,15 +67,15 @@ class Download:
self.author_archive = manager.author_archive
async def run(
self,
urls: list,
lives: list,
index: list | tuple | None,
nickname: str,
filename: str,
type_: str,
log,
bar,
self,
urls: list,
lives: list,
index: list | tuple | None,
nickname: str,
filename: str,
type_: str,
log,
bar,
) -> tuple[Path, list[Any]]:
path = self.__generate_path(nickname, filename)
if type_ == _("视频"):
@@ -109,7 +110,7 @@ class Download:
tasks = await gather(*tasks)
return path, tasks
def __generate_path(self, nickname:str, filename: str):
def __generate_path(self, nickname: str, filename: str):
if self.author_archive:
folder = self.folder.joinpath(nickname)
folder.mkdir(exist_ok=True)
@@ -120,7 +121,7 @@ class Download:
return path
def __ready_download_video(
self, urls: list[str], path: Path, name: str, log
self, urls: list[str], path: Path, name: str, log
) -> list:
if not self.video_download:
logging(log, _("视频作品下载功能已关闭,跳过下载"))
@@ -130,13 +131,13 @@ class Download:
return [(urls[0], name, self.video_format)]
def __ready_download_image(
self,
urls: list[str],
lives: list[str],
index: list | tuple | None,
path: Path,
name: str,
log,
self,
urls: list[str],
lives: list[str],
index: list | tuple | None,
path: Path,
name: str,
log,
) -> list:
tasks = []
if not self.image_download:
@@ -147,32 +148,32 @@ class Download:
continue
file = f"{name}_{i}"
if not any(
self.__check_exists_path(
path,
f"{file}.{s}",
log,
)
for s in self.image_format_list
self.__check_exists_path(
path,
f"{file}.{s}",
log,
)
for s in self.image_format_list
):
tasks.append([j[0], file, self.image_format])
if (
not self.live_download
or not j[1]
or self.__check_exists_path(
not self.live_download
or not j[1]
or self.__check_exists_path(
path,
f"{file}.{self.live_format}",
log,
)
)
):
continue
tasks.append([j[1], file, self.live_format])
return tasks
def __check_exists_glob(
self,
path: Path,
name: str,
log,
self,
path: Path,
name: str,
log,
) -> bool:
if any(path.glob(name)):
logging(log, _("{0} 文件已存在,跳过下载").format(name))
@@ -180,10 +181,10 @@ class Download:
return False
def __check_exists_path(
self,
path: Path,
name: str,
log,
self,
path: Path,
name: str,
log,
) -> bool:
if path.joinpath(name).exists():
logging(log, _("{0} 文件已存在,跳过下载").format(name))
@@ -192,13 +193,13 @@ class Download:
@re_download
async def __download(
self,
url: str,
path: Path,
name: str,
format_: str,
log,
bar,
self,
url: str,
path: Path,
name: str,
format_: str,
log,
bar,
):
async with self.SEMAPHORE:
headers = self.headers.copy()
@@ -224,9 +225,9 @@ class Download:
)
try:
async with self.client.stream(
"GET",
url,
headers=headers,
"GET",
url,
headers=headers,
) as response:
await sleep_time()
if response.status_code == 416:
@@ -276,9 +277,9 @@ class Download:
@staticmethod
def __create_progress(
bar,
total: int | None,
completed=0,
bar,
total: int | None,
completed=0,
):
if bar:
bar.update(total=total, completed=completed)
@@ -293,10 +294,10 @@ class Download:
return cls.CONTENT_TYPE_MAP.get(content, "")
async def __head_file(
self,
url: str,
headers: dict[str, str],
suffix: str,
self,
url: str,
headers: dict[str, str],
suffix: str,
) -> tuple[int, str]:
response = await self.client.head(
url,
@@ -313,26 +314,26 @@ class Download:
return file.stat().st_size if file.is_file() else 0
def __update_headers_range(
self,
headers: dict[str, str],
file: Path,
self,
headers: dict[str, str],
file: Path,
) -> int:
headers["Range"] = f"bytes={(p := self.__get_resume_byte_position(file))}-"
return p
async def __suffix_with_file(
self,
temp: Path,
path: Path,
name: str,
default_suffix: str,
log,
self,
temp: Path,
path: Path,
name: str,
default_suffix: str,
log,
) -> Path:
try:
async with open(temp, "rb") as f:
file_start = await f.read(FILE_SIGNATURES_LENGTH)
for offset, signature, suffix in FILE_SIGNATURES:
if file_start[offset: offset + len(signature)] == signature:
if file_start[offset : offset + len(signature)] == signature:
return path.joinpath(f"{name}.{suffix}")
except Exception as error:
logging(

View File

@@ -1,4 +1,5 @@
from source.expansion import Namespace
from .request import Html
__all__ = ["Image"]
@@ -37,8 +38,8 @@ class Image:
@staticmethod
def __generate_fixed_link(
token: str,
format_: str,
token: str,
format_: str,
) -> str:
return f"https://ci.xiaohongshu.com/{token}?imageView2/format/{format_}"
@@ -50,10 +51,10 @@ class Image:
def __get_live_link(items: list) -> list:
return [
(
Html.format_url(
Namespace.object_extract(item, "stream.h264[0].masterUrl")
)
or None
Html.format_url(
Namespace.object_extract(item, "stream.h264[0].masterUrl")
)
or None
)
for item in items
]

View File

@@ -1,19 +1,20 @@
from typing import TYPE_CHECKING
from httpx import HTTPError
from ..module import ERROR
from ..module import Manager
from ..module import logging
from ..module import retry
from ..module import sleep_time
from ..module import ERROR, Manager, logging, retry, sleep_time
from ..translation import _
if TYPE_CHECKING:
from ..module import Manager
__all__ = ["Html"]
class Html:
def __init__(
self,
manager: Manager,
self,
manager: "Manager",
):
self.retry = manager.retry
self.client = manager.request_client
@@ -21,12 +22,12 @@ class Html:
@retry
async def request_url(
self,
url: str,
content=True,
log=None,
cookie: str = None,
**kwargs,
self,
url: str,
content=True,
log=None,
cookie: str = None,
**kwargs,
) -> str:
headers = self.update_cookie(
cookie,
@@ -63,16 +64,16 @@ class Html:
return bytes(url, "utf-8").decode("unicode_escape")
def update_cookie(
self,
cookie: str = None,
self,
cookie: str = None,
) -> dict:
return self.headers | {"Cookie": cookie} if cookie else self.headers.copy()
async def __request_url_head(
self,
url: str,
headers: dict,
**kwargs,
self,
url: str,
headers: dict,
**kwargs,
):
return await self.client.head(
url,
@@ -81,10 +82,10 @@ class Html:
)
async def __request_url_get(
self,
url: str,
headers: dict,
**kwargs,
self,
url: str,
headers: dict,
**kwargs,
):
return await self.client.get(
url,