# Source code for scratchattach.utils.requests

from __future__ import annotations

from collections.abc import MutableMapping, Iterator
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Optional, Any, Self, Union
from contextlib import contextmanager
from enum import Enum, auto
from dataclasses import dataclass, field
import json

from aiohttp.cookiejar import DummyCookieJar
from typing_extensions import override
from requests import Session as HTTPSession
from requests import Response
import aiohttp

from . import exceptions
from . import optional_async

# Proxy mapping shared by the session classes in this module; unset (``None``)
# by default. It is passed as ``proxies=`` to the ``requests``-based calls, and
# the async backend reads its "http" / "https" entries.
proxies: Optional[MutableMapping[str, str]] = None

class HTTPMethod(Enum):
    """HTTP request verbs accepted by the session classes in this module."""

    GET = auto()
    POST = auto()
    PUT = auto()
    DELETE = auto()
    HEAD = auto()
    OPTIONS = auto()
    PATCH = auto()
    TRACE = auto()

    @classmethod
    def of(cls, name: str) -> HTTPMethod:
        """Return the member whose name is exactly ``name``.

        The lookup is case-sensitive, so a lower-case verb must be
        upper-cased by the caller first.

        Raises:
            KeyError: If ``name`` is not the name of a member.
        """
        # Name-based enum lookup replaces a hand-written name -> member dict
        # that had to be kept in sync with the member list above.
        return cls[name]
class AnyHTTPResponse(ABC):
    """Base class for HTTP response objects returned by the sessions here."""

    # These are annotations only; no values are assigned at class level.
    request_method: HTTPMethod
    status_code: int
    content: bytes
    text: str
    headers: dict[str, str]

    def json(self) -> Any:
        """Parse ``self.text`` as JSON and return the decoded object."""
        decoded = json.loads(self.text)
        return decoded
# kw_only at the decorator level makes every field below keyword-only.
@dataclass(kw_only=True)
class HTTPResponse(AnyHTTPResponse):
    """Concrete response record; all fields must be passed by keyword."""

    request_method: HTTPMethod
    status_code: int
    content: bytes
    text: str
    headers: dict[str, str]
class OAHTTPSession(ABC):
    """Abstract HTTP session exposing one request API for sync and async use.

    Subclasses implement :meth:`sync_request` and :meth:`async_request`;
    :meth:`request` packages a call into an ``optional_async.CARequest``.

    Attributes:
        error_handling: When true, the implementations in this module pass
            each response to :meth:`check_response` before returning it.
    """

    error_handling: bool = True

    @abstractmethod
    def sync_request(
        self,
        method: HTTPMethod,
        url: str,
        *,
        cookies: Optional[dict[str, str]] = None,
        headers: Optional[dict[str, str]] = None,
        params: Optional[dict[str, str]] = None,
        data: Optional[Union[dict[str, str], str]] = None,
        json: Optional[Any] = None
    ) -> AnyHTTPResponse:
        """Perform the request synchronously and return the response."""

    @abstractmethod
    async def async_request(
        self,
        method: HTTPMethod,
        url: str,
        *,
        cookies: Optional[dict[str, str]] = None,
        headers: Optional[dict[str, str]] = None,
        params: Optional[dict[str, str]] = None,
        data: Optional[Union[dict[str, str], str]] = None,
        json: Optional[Any] = None
    ) -> AnyHTTPResponse:
        """Perform the request asynchronously and return the response."""

    def check_response(self, r: AnyHTTPResponse):
        """Raise the matching exception if ``r`` signals a known failure.

        Raises:
            exceptions.Unauthorized: Status code 401 or 403.
            exceptions.APIError: Status code 500.
            exceptions.Response429: Status code 429.
            exceptions.BadRequest: The body decodes to the JSON object
                ``{"code": "BadRequest", "message": ""}``.
        """
        if r.status_code in (401, 403):
            raise exceptions.Unauthorized(f"Request content: {r.content!r}")
        if r.status_code == 500:
            raise exceptions.APIError("Internal Scratch server error")
        if r.status_code == 429:
            raise exceptions.Response429("You are being rate-limited (or blocked) by Scratch")
        try:
            body = r.json()
        except ValueError:
            # JSON decode errors subclass ValueError. An empty or non-JSON
            # body cannot be the BadRequest marker, so it is not an error.
            return
        if body == {"code":"BadRequest","message":""}:
            raise exceptions.BadRequest("Make sure all provided arguments are valid")

    def request(
        self,
        method: Union[HTTPMethod, str],
        url: str,
        *,
        cookies: Optional[dict[str, str]] = None,
        headers: Optional[dict[str, str]] = None,
        params: Optional[dict[str, str]] = None,
        data: Optional[Union[dict[str, str], str]] = None,
        json: Optional[Any] = None
    ) -> optional_async.CARequest:
        """Build an ``optional_async.CARequest`` for this session.

        ``method`` may be an :class:`HTTPMethod` or its name in any letter
        case; an unknown name raises ``KeyError``.
        """
        if isinstance(method, str):
            method = HTTPMethod.of(method.upper())
        return optional_async.CARequest(
            self,
            method,
            url,
            cookies=cookies,
            headers=headers,
            params=params,
            data=data,
            json=json,
        )

    @contextmanager
    def no_error_handling(self) -> Iterator[None]:
        """Disable ``error_handling`` inside the ``with`` block, then restore it."""
        val_before = self.error_handling
        self.error_handling = False
        try:
            yield
        finally:
            self.error_handling = val_before

    @contextmanager
    def yes_error_handling(self) -> Iterator[None]:
        """Enable ``error_handling`` inside the ``with`` block, then restore it."""
        val_before = self.error_handling
        self.error_handling = True
        try:
            yield
        finally:
            self.error_handling = val_before
class SyncRequests(OAHTTPSession):
    """Blocking backend that sends through the module-level ``requests`` session."""

    @override
    def sync_request(self, method, url, *, cookies = None, headers = None, params = None, data = None, json = None):
        """Send the request, wrap the result in an ``HTTPResponse`` and return it.

        Any exception raised while sending is re-raised as
        ``exceptions.FetchError``. When ``error_handling`` is on, the wrapped
        response is passed to ``check_response`` before being returned.
        """
        try:
            raw = requests.request(
                method.name,
                url,
                cookies=cookies,
                headers=headers,
                params=params,
                data=data,
                json=json,
                proxies=proxies,
            )
        except Exception as e:
            raise exceptions.FetchError(e)

        wrapped = HTTPResponse(
            request_method=method,
            status_code=raw.status_code,
            content=raw.content,
            text=raw.text,
            headers=raw.headers,
        )
        if not self.error_handling:
            return wrapped
        self.check_response(wrapped)
        return wrapped

    async def async_request(self, method, url, *, cookies = None, headers = None, params = None, data = None, json = None):
        """Not available on this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError()
class AsyncRequests(OAHTTPSession):
    """Asynchronous backend built on ``aiohttp``.

    Use it as an async context manager: entering creates and enters an
    ``aiohttp.ClientSession``, exiting exits that session.
    """

    client_session: aiohttp.ClientSession

    async def __aenter__(self) -> Self:
        """Open the underlying client session and return ``self``."""
        # The session gets a DummyCookieJar; cookies are passed per request
        # through the ``cookies`` argument of ``async_request``.
        self.client_session = await aiohttp.ClientSession(cookie_jar=DummyCookieJar()).__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]] = None,
        exc_val: Optional[BaseException] = None,
        exc_tb: Optional[TracebackType] = None
    ) -> None:
        """Forward the exit to the underlying client session."""
        await self.client_session.__aexit__(exc_type, exc_val, exc_tb)

    @override
    def sync_request(self, method, url, *, cookies = None, headers = None, params = None, data = None, json = None):
        """Not available on this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    async def async_request(self, method, url, *, cookies = None, headers = None, params = None, data = None, json = None):
        """Send the request with aiohttp and return it as an ``HTTPResponse``.

        The proxy is taken from the module-level ``proxies`` mapping ("https"
        entry for https URLs, "http" entry for other http URLs); no proxy is
        used when the mapping is unset. When ``error_handling`` is on, the
        response is passed to ``check_response`` before being returned.
        """
        proxy = None
        # ``proxies`` defaults to None; calling .get() on it unguarded would
        # make every request fail with AttributeError.
        if proxies is not None:
            if url.startswith("https"):
                proxy = proxies.get("https")
            elif url.startswith("http"):
                proxy = proxies.get("http")

        async with self.client_session.request(
            method.name,
            url,
            cookies=cookies,
            headers=headers,
            params=params,
            data=data,
            json=json,
            proxy=proxy,
        ) as resp:
            assert isinstance(resp, aiohttp.ClientResponse)
            content = await resp.read()
            try:
                text = content.decode(resp.get_encoding())
            except Exception:
                # Best effort: a body that cannot be decoded yields empty
                # text; the raw bytes stay available in ``content``.
                text = ""
            response = HTTPResponse(
                request_method=method,
                status_code=resp.status,
                content=content,
                text=text,
                headers=resp.headers,
            )
            if self.error_handling:
                self.check_response(response)
            return response
class Requests(HTTPSession):
    """
    Centralized HTTP request handler (for better error handling and proxies)

    ``get``, ``post``, ``delete`` and ``put`` default the ``proxies`` keyword
    to the module-level ``proxies`` value, re-raise any exception from the
    underlying call as ``exceptions.FetchError`` and, while
    ``error_handling`` is true, pass the response to ``check_response``.
    """

    error_handling: bool = True

    def check_response(self, r: Response):
        """Raise the matching exception if ``r`` signals a known failure.

        Raises:
            exceptions.Unauthorized: Status code 401 or 403.
            exceptions.APIError: Status code 500.
            exceptions.Response429: Status code 429.
            exceptions.BadRequest: The body decodes to the JSON object
                ``{"code": "BadRequest", "message": ""}``.
        """
        if r.status_code in (401, 403):
            raise exceptions.Unauthorized(f"Request content: {r.content!r}")
        if r.status_code == 500:
            raise exceptions.APIError("Internal Scratch server error")
        if r.status_code == 429:
            raise exceptions.Response429("You are being rate-limited (or blocked) by Scratch")
        try:
            body = r.json()
        except ValueError:
            # JSON decode errors subclass ValueError. An empty or non-JSON
            # body cannot be the BadRequest marker, so it is not an error.
            return
        if body == {"code":"BadRequest","message":""}:
            raise exceptions.BadRequest("Make sure all provided arguments are valid")

    def _checked(self, send, args, kwargs):
        """Call the parent-class verb ``send`` with default proxies and check the result."""
        kwargs.setdefault("proxies", proxies)
        try:
            r = send(*args, **kwargs)
        except Exception as e:
            raise exceptions.FetchError(e) from e
        if self.error_handling:
            self.check_response(r)
        return r

    @override
    def get(self, *args, **kwargs):
        """Send a GET request through the shared proxy/error-handling path."""
        return self._checked(super().get, args, kwargs)

    @override
    def post(self, *args, **kwargs):
        """Send a POST request through the shared proxy/error-handling path."""
        return self._checked(super().post, args, kwargs)

    @override
    def delete(self, *args, **kwargs):
        """Send a DELETE request through the shared proxy/error-handling path."""
        return self._checked(super().delete, args, kwargs)

    @override
    def put(self, *args, **kwargs):
        """Send a PUT request through the shared proxy/error-handling path."""
        return self._checked(super().put, args, kwargs)

    @contextmanager
    def no_error_handling(self) -> Iterator[None]:
        """Disable ``error_handling`` inside the ``with`` block, then restore it."""
        val_before = self.error_handling
        self.error_handling = False
        try:
            yield
        finally:
            self.error_handling = val_before

    @contextmanager
    def yes_error_handling(self) -> Iterator[None]:
        """Enable ``error_handling`` inside the ``with`` block, then restore it."""
        val_before = self.error_handling
        self.error_handling = True
        try:
            yield
        finally:
            self.error_handling = val_before
# Shared module-level session instance; ``SyncRequests.sync_request`` sends
# its requests through it.
requests = Requests()