11import ipaddress
22import socket
3- from urllib .parse import urlparse
43
54import httpx
65
7- __http_timeout = 15
8-
96
107class ForcedTimeoutException (Exception ):
118 """
@@ -23,14 +20,12 @@ class InvalidDomainError(Exception):
2320 ...
2421
2522
26- def resolve_ip_with_socket (url ):
23+ def resolve_ip_with_socket (domain_name : str ):
2724 """
2825 Resolve the IP address of a given URL. If the URL is invalid,
2926 return None.
3027 """
3128 try :
32- parsed_url = urlparse (url )
33- domain_name = parsed_url .netloc
3429 ip_address = socket .gethostbyname (domain_name )
3530 return ip_address
3631 except (socket .gaierror , ValueError ):
@@ -74,19 +69,23 @@ class AsyncSafeTransport(httpx.AsyncBaseTransport):
7469 and that the request is not made to a local IP address.
7570 """
7671
72+ timeout : int = 15
73+
7774 def __init__ (self , ** kwargs ):
75+ self .timeout = kwargs .pop ("timeout" , self .timeout )
7876 self ._wrapper = httpx .AsyncHTTPTransport (** kwargs )
7977
8078 async def handle_async_request (self , request ):
8179 # override timeout value for _all_ requests
82- request .extensions ["timeout" ] = httpx .Timeout (__http_timeout )
80+ request .extensions ["timeout" ] = httpx .Timeout (self . timeout , pool = self . timeout ). as_dict ( )
8381
8482 # validate the request is not attempting to connect to a local IP
8583 # This is a security measure to prevent SSRF attacks
8684
87- ip_address = resolve_ip_with_socket (request .url )
85+ ip_address = resolve_ip_with_socket (str ( request .url . netloc ) )
8886
8987 if ip_address and is_local_ip (ip_address ):
88+ print ("HERE, I'M HERE" )
9089 raise InvalidDomainError (f"invalid request on local resource: { request .url } -> { ip_address } " )
9190
9291 return await self ._wrapper .handle_async_request (request )
0 commit comments