"""HTTP/2 connection pool and request agents built on Twisted endpoints."""
from collections import deque
from typing import Deque, Dict, List, Optional, Tuple

from twisted.internet import defer
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
from twisted.internet.endpoints import HostnameEndpoint
from twisted.python.failure import Failure
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _StandardEndpointFactory
from twisted.web.error import SchemeNotSupported

from scrapy.core.downloader.contextfactory import AcceptableProtocolsContextFactory
from scrapy.core.http2.protocol import H2ClientProtocol, H2ClientFactory
from scrapy.http.request import Request
from scrapy.settings import Settings
from scrapy.spiders import Spider


class H2ConnectionPool:
    """Keep one ``H2ClientProtocol`` per connection key and share it.

    Callers asking for a connection while it is still being established
    are queued and resolved together once the connection attempt settles.
    """

    def __init__(self, reactor: ReactorBase, settings: Settings) -> None:
        self._reactor = reactor
        self.settings = settings

        # Maps a connection key -- Tuple(scheme, hostname, port) -- to the
        # H2ClientProtocol instance serving it.
        self._connections: Dict[Tuple, H2ClientProtocol] = {}

        # Deferreds handed out for a key while its connection is still being
        # established. The mere presence of a key here means "connecting".
        self._pending_requests: Dict[Tuple, Deque[Deferred]] = {}

    def get_connection(self, key: Tuple, uri: URI, endpoint: HostnameEndpoint) -> Deferred:
        """Return a Deferred that fires with the connection for ``key``.

        Arguments:
            key - pool key identifying the remote
            uri - URI passed on to the client factory for a new connection
            endpoint - endpoint used to connect if no connection exists yet
        """
        if key in self._pending_requests:
            # A connection attempt for this key is already in flight: queue
            # a deferred that is resolved once that attempt settles.
            d = Deferred()
            self._pending_requests[key].append(d)
            return d

        # Reuse an already established connection when there is one.
        conn = self._connections.get(key)
        if conn:
            return defer.succeed(conn)

        # Nothing established and nothing in flight: start connecting.
        return self._new_connection(key, uri, endpoint)

    def _new_connection(self, key: Tuple, uri: URI, endpoint: HostnameEndpoint) -> Deferred:
        """Start connecting through ``endpoint`` and queue the first waiter.

        Returns the Deferred of the first waiter; it fires with the
        connection on success or with the failure if connecting fails.
        """
        # Register the queue *before* connecting so that concurrent callers
        # of get_connection() queue up instead of opening a second connection.
        self._pending_requests[key] = deque()

        # Handed to the factory; presumably fired by the protocol once the
        # connection is lost -- confirm against H2ClientFactory.
        conn_lost_deferred = Deferred()
        conn_lost_deferred.addCallback(self._remove_connection, key)

        factory = H2ClientFactory(uri, self.settings, conn_lost_deferred)
        conn_d = endpoint.connect(factory)
        # Without the errback a failed connection attempt would leave the
        # pending entry in place forever: the queued deferreds would never
        # fire and every later request for this key would be queued too.
        conn_d.addCallbacks(
            self.put_connection,
            self._connection_failed,
            callbackArgs=(key,),
            errbackArgs=(key,),
        )

        d = Deferred()
        self._pending_requests[key].append(d)
        return d

    def put_connection(self, conn: H2ClientProtocol, key: Tuple) -> H2ClientProtocol:
        """Register ``conn`` under ``key`` and hand it to all queued waiters.

        Returns ``conn`` so it keeps flowing down the callback chain.
        """
        self._connections[key] = conn

        # The connection is available now: fire every deferred queued while
        # it was being established.
        pending_requests = self._pending_requests.pop(key, None)
        while pending_requests:
            d = pending_requests.popleft()
            d.callback(conn)

        return conn

    def _connection_failed(self, failure: Failure, key: Tuple) -> None:
        """Fail every waiter queued for ``key`` after a failed connect.

        Removing the pending entry lets the next get_connection() call for
        this key start a fresh connection attempt. Returning ``None``
        consumes the failure, which has been delivered to the waiters.
        """
        pending_requests = self._pending_requests.pop(key, None)
        while pending_requests:
            d = pending_requests.popleft()
            d.errback(failure)

    def _remove_connection(self, errors: List[BaseException], key: Tuple) -> None:
        """Forget the connection for ``key`` and fail its queued waiters.

        Used as the callback of the connection-lost deferred given to the
        client factory; ``errors`` is whatever that deferred is fired with.
        """
        # The key may not be registered (put_connection() has not run for
        # it); a KeyError here would skip failing the waiters below.
        self._connections.pop(key, None)

        # Call the errback of all the pending requests for this connection
        pending_requests = self._pending_requests.pop(key, None)
        while pending_requests:
            d = pending_requests.popleft()
            # NOTE(review): ``errors`` is annotated as a list rather than a
            # single exception/Failure -- consider wrapping it in an
            # exception type before passing it to errback(); confirm how
            # the installed Twisted version treats non-exception values.
            d.errback(errors)

    def close_connections(self) -> None:
        """Abort the transport of every pooled HTTP/2 connection.

        Nothing is returned and entries are not removed here; removal is
        done by _remove_connection() when the corresponding connection-lost
        deferred fires.
        """
        # Iterate over a snapshot: if a transport reports the loss
        # synchronously, _remove_connection() mutates the dict during the
        # loop.
        for conn in list(self._connections.values()):
            conn.transport.abortConnection()


class H2Agent:
    """Issue requests over pooled HTTP/2 connections."""

    def __init__(
        self,
        reactor: ReactorBase,
        pool: H2ConnectionPool,
        context_factory: BrowserLikePolicyForHTTPS = BrowserLikePolicyForHTTPS(),
        connect_timeout: Optional[float] = None,
        bind_address: Optional[bytes] = None,
    ) -> None:
        self._reactor = reactor
        self._pool = pool
        # Wrap the TLS policy so that only "h2" is an acceptable protocol.
        self._context_factory = AcceptableProtocolsContextFactory(context_factory, acceptable_protocols=[b'h2'])
        self.endpoint_factory = _StandardEndpointFactory(
            self._reactor, self._context_factory,
            connect_timeout, bind_address
        )

    def get_endpoint(self, uri: URI):
        """Return the endpoint used to connect to ``uri``."""
        return self.endpoint_factory.endpointForURI(uri)

    def get_key(self, uri: URI) -> Tuple:
        """
        Arguments:
            uri - URI obtained directly from request URL

        Returns the pool key ``(scheme, host, port)`` for ``uri``.
        """
        return uri.scheme, uri.host, uri.port

    def request(self, request: Request, spider: Spider) -> Deferred:
        """Send ``request`` over a pooled connection.

        Returns a Deferred carrying the result of the connection's
        ``request()`` call, or a failed Deferred when the URL scheme is
        not supported by the endpoint factory.
        """
        uri = URI.fromBytes(bytes(request.url, encoding='utf-8'))
        try:
            endpoint = self.get_endpoint(uri)
        except SchemeNotSupported:
            # Failure() with no arguments captures the exception currently
            # being handled.
            return defer.fail(Failure())

        key = self.get_key(uri)
        d = self._pool.get_connection(key, uri, endpoint)
        d.addCallback(lambda conn: conn.request(request, spider))
        return d


class ScrapyProxyH2Agent(H2Agent):
    """H2Agent variant that connects to a proxy instead of the request host."""

    def __init__(
        self,
        reactor: ReactorBase,
        proxy_uri: URI,
        pool: H2ConnectionPool,
        context_factory: BrowserLikePolicyForHTTPS = BrowserLikePolicyForHTTPS(),
        connect_timeout: Optional[float] = None,
        bind_address: Optional[bytes] = None,
    ) -> None:
        super().__init__(
            reactor=reactor,
            pool=pool,
            context_factory=context_factory,
            connect_timeout=connect_timeout,
            bind_address=bind_address,
        )
        self._proxy_uri = proxy_uri

    def get_endpoint(self, uri: URI):
        """Return the endpoint for the proxy; ``uri`` itself is not used."""
        return self.endpoint_factory.endpointForURI(self._proxy_uri)

    def get_key(self, uri: URI) -> Tuple:
        """We use the proxy uri instead of uri obtained from request url"""
        return "http-proxy", self._proxy_uri.host, self._proxy_uri.port