From 395d81c52228a8c82f05a5180ee6d5c8bc521447 Mon Sep 17 00:00:00 2001 From: JoeamAmier Date: Tue, 5 Dec 2023 22:48:21 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E9=A1=B9=E7=9B=AE=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- main.py | 35 ++++++++++---------- requirements.txt | 2 +- source/Downloader.py | 59 +++++++++++++++++++++------------ source/Html.py | 46 ++++++++++++-------------- source/Manager.py | 7 +--- source/__init__.py | 78 ++++++++++++++++++++++++++++---------------- 7 files changed, 129 insertions(+), 100 deletions(-) diff --git a/README.md b/README.md index 7dfc1cb..b01126d 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ with XHS(path=path, Download -proxies +proxy str 设置代理 无 diff --git a/main.py b/main.py index 813051e..46bad1f 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,9 @@ +from asyncio import run + from source import XHS -from source import XHSDownloader -def example(): +async def example(): """通过代码设置参数,适合二次开发""" # 测试链接 error_demo = "https://github.com/JoeanAmier/XHS_Downloader" @@ -15,22 +16,22 @@ def example(): proxies = None # 网络代理 timeout = 5 # 网络请求超时限制,默认值:10 chunk = 1024 * 1024 # 下载文件时,每次从服务器获取的数据块大小,单位字节 - # with XHS() as xhs: - # pass # 使用默认参数 - with XHS(path=path, - folder=folder, - proxies=proxies, - timeout=timeout, - chunk=chunk) as xhs: # 使用自定义参数 - download = True # 是否下载作品文件,默认值:False + async with XHS() as xhs: + pass # 使用默认参数 + async with XHS(path=path, + folder=folder, + proxy=proxies, + timeout=timeout, + chunk=chunk) as xhs: # 使用自定义参数 + download = False # 是否下载作品文件,默认值:False # 返回作品详细信息,包括下载地址 - print(xhs.extract(error_demo)) # 获取数据失败时返回空字典 - print(xhs.extract(image_demo, download=download)) - print(xhs.extract(video_demo, download=download)) - print(xhs.extract(multiple_demo, download=download)) + print(await xhs.extract(error_demo)) # 获取数据失败时返回空字典 + print(await xhs.extract(image_demo, download=download)) + print(await xhs.extract(video_demo, download=download)) + print(await xhs.extract(multiple_demo, download=download)) if __name__ == '__main__': - # example() - with XHSDownloader() as xhs: - xhs.run() + run(example()) + # with XHSDownloader() as xhs: + # xhs.run() diff --git a/requirements.txt b/requirements.txt index 016d367..8263405 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -requests>=2.31.0 +aiohttp>=3.9.0 textual>=0.40.0 diff --git a/source/Downloader.py b/source/Downloader.py index 2976a61..c69b31d 100644 --- a/source/Downloader.py +++ b/source/Downloader.py @@ -1,7 +1,11 @@ from pathlib import Path -from requests import exceptions -from requests import get +from aiohttp import ClientConnectionError +from aiohttp import ClientProxyConnectionError +from aiohttp import ClientSSLError +from aiohttp import ClientSession + +# from aiohttp import ClientTimeout __all__ = ['Download'] @@ -14,20 +18,16 @@ class Download: root: Path, path: str, folder: str, - proxies=None, + proxy: str = None, chunk=1024 * 1024, timeout=10, ): self.manager = manager self.temp = manager.temp - self.headers = manager.headers self.root = self.__init_root(root, path, folder) - self.proxies = { - "http": proxies, - "https": proxies, - "ftp": proxies, - } + self.proxy = proxy self.chunk = chunk - self.timeout = timeout + # self.timeout = ClientTimeout(total=timeout) + self.session = ClientSession(headers=manager.headers) def __init_root(self, root: Path, path: str, folder: str) -> Path: if path and (r := Path(path)).is_dir(): @@ -38,29 +38,46 @@ class Download: self.temp.mkdir(exist_ok=True) return root - def run(self, urls: list, name: str, type_: int): + async def run(self, urls: list, name: str, type_: int, log, bar): if type_ == 0: - self.__download(urls[0], f"{name}.mp4") + await self.__download(urls[0], f"{name}.mp4", log, bar) elif type_ == 1: for index, url in enumerate(urls): - self.__download(url, f"{name}_{index + 1}.png") + await self.__download(url, f"{name}_{index + 1}.png", log, bar) - def __download(self, url: str, name: str): + async def __download(self, url: str, name: str, log, bar): temp = self.temp.joinpath(name) file = self.root.joinpath(name) if self.manager.is_exists(file): return try: - with get(url, headers=self.headers, proxies=self.proxies, stream=True, timeout=self.timeout) as response: + async with self.session.get(url, proxy=self.proxy) as response: + # self.__create_progress(bar, int(response.headers.get('content-length', 0))) with temp.open("wb") as f: - for chunk in response.iter_content(chunk_size=self.chunk): + async for chunk in response.content.iter_chunked(self.chunk): f.write(chunk) + # self.__update_progress(bar, len(chunk)) + # self.__remove_progress(bar) self.manager.move(temp, file) except ( - exceptions.ProxyError, - exceptions.SSLError, - exceptions.ChunkedEncodingError, - exceptions.ConnectionError, - exceptions.ReadTimeout, + ClientProxyConnectionError, + ClientSSLError, + ClientConnectionError, + TimeoutError, ): self.manager.delete(temp) + # self.__remove_progress(bar) + + # @staticmethod + # def __create_progress(bar, total: int | None): + # if bar: + # bar.update(total=total) + # + # @staticmethod + # def __update_progress(bar, advance: int): + # if bar: + # bar.advance(advance) + # + # @staticmethod + # def __remove_progress(bar): + # pass diff --git a/source/Html.py b/source/Html.py index aad8339..8be1848 100644 --- a/source/Html.py +++ b/source/Html.py @@ -1,5 +1,9 @@ -from requests import exceptions -from requests import get +from aiohttp import ClientConnectionError +from aiohttp import ClientProxyConnectionError +from aiohttp import ClientSSLError +from aiohttp import ClientSession + +# from aiohttp import ClientTimeout __all__ = ['Html'] @@ -9,39 +13,29 @@ class Html: def __init__( self, headers: dict, - proxies=None, + proxy: str = None, timeout=10, ): - self.headers = headers | {"Referer": "https://www.xiaohongshu.com/", } - self.proxies = { - "http": proxies, - "https": proxies, - "ftp": proxies, - } - self.timeout = timeout + self.proxy = proxy + self.session = ClientSession( + headers=headers | { + "Referer": "https://www.xiaohongshu.com/", }) - def request_url( + async def request_url( self, url: str, - params=None, - headers=None, text=True, ) -> str: try: - response = get( - url, - params=params, - proxies=self.proxies, - timeout=self.timeout, - headers=headers or self.headers, ) + async with self.session.get( + url, + proxy=self.proxy, + ) as response: + return await response.text() if text else response.url except ( - exceptions.ProxyError, - exceptions.SSLError, - exceptions.ChunkedEncodingError, - exceptions.ConnectionError, - exceptions.ReadTimeout, + ClientProxyConnectionError, + ClientSSLError, + ClientConnectionError, ): - print("网络异常,获取网页源码失败!") return "" - return response.text if text else response.url @staticmethod def format_url(url: str) -> str: diff --git a/source/Manager.py b/source/Manager.py index 2c263c8..502d921 100644 --- a/source/Manager.py +++ b/source/Manager.py @@ -2,7 +2,7 @@ from pathlib import Path from shutil import move from shutil import rmtree -__all__ = ['Manager', "rich_log"] +__all__ = ["Manager"] class Manager: @@ -27,8 +27,3 @@ class Manager: def clean(self): rmtree(self.temp.resolve()) - - -def rich_log(log, text): - if log: - log.write(text) diff --git a/source/__init__.py b/source/__init__.py index f3741d2..b69089d 100644 --- a/source/__init__.py +++ b/source/__init__.py @@ -2,9 +2,11 @@ from pathlib import Path from re import compile from pyperclip import paste +from rich.text import Text from textual.app import App from textual.app import ComposeResult from textual.binding import Binding +from textual.containers import Center from textual.containers import HorizontalScroll from textual.containers import ScrollableContainer from textual.widgets import Button @@ -12,6 +14,7 @@ from textual.widgets import Footer from textual.widgets import Header from textual.widgets import Input from textual.widgets import Label +from textual.widgets import ProgressBar from textual.widgets import RichLog from .Downloader import Download @@ -19,7 +22,6 @@ from .Explore import Explore from .Html import Html from .Image import Image from .Manager import Manager -from .Manager import rich_log from .Settings import Settings from .Video import Video @@ -36,13 +38,13 @@ class XHS: self, path="", folder="Download", - proxies=None, + proxy=None, timeout=10, chunk=1024 * 1024, **kwargs, ): self.manager = Manager(self.ROOT) - self.html = Html(self.manager.headers, proxies, timeout) + self.html = Html(self.manager.headers, proxy, timeout) self.image = Image() self.video = Video() self.explore = Explore() @@ -51,55 +53,56 @@ class XHS: self.ROOT, path, folder, - proxies, + proxy, chunk, timeout) - def __get_image(self, container: dict, html: str, download, log): + async def __get_image(self, container: dict, html: str, download, log, bar): urls = self.image.get_image_link(html) - # rich_log(log, urls) # 调试代码 + # self.rich_log(log, urls) # 调试代码 if download: - self.download.run(urls, self.__naming_rules(container), 1) + await self.download.run(urls, self.__naming_rules(container), 1, log, bar) container["下载地址"] = urls - def __get_video(self, container: dict, html: str, download, log): + async def __get_video(self, container: dict, html: str, download, log, bar): url = self.video.get_video_link(html) - # rich_log(log, url) # 调试代码 + # self.rich_log(log, url) # 调试代码 if download: - self.download.run(url, self.__naming_rules(container), 0) + await self.download.run(url, self.__naming_rules(container), 0, log, bar) container["下载地址"] = url - def extract(self, url: str, download=False, log=None) -> list[dict]: - urls = self.__deal_links(url) - # rich_log(log, urls) # 调试代码 + async def extract(self, url: str, download=False, log=None, bar=None) -> list[dict]: + # return # 调试代码 + urls = await self.__deal_links(url) + # self.rich_log(log, urls) # 调试代码 # return urls # 调试代码 - return [self.__deal_extract(i, download, log) for i in urls] + return [await self.__deal_extract(i, download, log, bar) for i in urls] - def __deal_links(self, url: str) -> list: + async def __deal_links(self, url: str) -> list: urls = [] for i in url.split(): if u := self.short.search(i): - i = self.html.request_url( - u.group(), headers=self.manager.headers, text=False) + i = await self.html.request_url( + u.group(), False) if u := self.share.search(i): urls.append(u.group()) elif u := self.link.search(i): urls.append(u.group()) return urls - def __deal_extract(self, url: str, download: bool, log): - html = self.html.request_url(url) - # rich_log(log, html) # 调试代码 + async def __deal_extract(self, url: str, download: bool, log, bar): + html = await self.html.request_url(url) + # self.rich_log(log, html) # 调试代码 if not html: return {} data = self.explore.run(html) - # rich_log(log, data) # 调试代码 + # self.rich_log(log, data) # 调试代码 if not data: return {} if data["作品类型"] == "视频": - self.__get_video(data, html, download, log) + await self.__get_video(data, html, download, log, bar) else: - self.__get_image(data, html, download, log) + await self.__get_image(data, html, download, log, bar) return data @staticmethod @@ -107,11 +110,19 @@ class XHS: """下载文件默认使用作品 ID 作为文件名,可修改此方法自定义文件名格式""" return data["作品ID"] - def __enter__(self): + async def __aenter__(self): return self - def __exit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, exc_type, exc_value, traceback): self.manager.clean() + await self.html.session.close() + await self.download.session.close() + + def rich_log(self, log, text, style="b bright_green"): + if log: + log.write(Text(text, style=style)) + else: + self.console.print(text, style=style) class XHSDownloader(App): @@ -126,6 +137,9 @@ class XHSDownloader(App): ("d", "toggle_dark", "切换主题"), ] + def __init__(self): + super().__init__() + def __enter__(self): return self @@ -138,7 +152,10 @@ class XHSDownloader(App): Input(placeholder="多个链接之间使用空格分隔"), HorizontalScroll(Button("下载无水印图片/视频", id="deal"), Button("读取剪贴板", id="paste"), - Button("清空输入框", id="reset"), )) + Button("清空输入框", id="reset"), ), + ) + with Center(): + yield ProgressBar(total=None) yield RichLog(markup=True) yield Footer() @@ -158,6 +175,11 @@ class XHSDownloader(App): def deal(self): url = self.query_one(Input) log = self.query_one(RichLog) - if self.APP.extract(url.value, True, log=log): - pass + bar = self.query_one(ProgressBar) + if not url.value: + log.write(Text("未输入任何小红书作品链接!", style="yellow")) + return + _ = self.APP.extract(url.value, True, log=log, bar=bar) + if not _: + log.write(Text("获取小红书作品数据失败!", style="red")) url.value = ""