Refactor proxy server code to improve readability

and maintainability
This commit is contained in:
Jose Henrique 2023-12-06 17:44:29 -03:00
parent 12f4176fca
commit 870b630027
2 changed files with 87 additions and 41 deletions

View File

@ -7,6 +7,7 @@ import logging
import logging.handlers import logging.handlers
from logging import config from logging import config
CONNECTION_TIMEOUT_SEC = 10
BUFFER_SIZE = 32 * 1024 BUFFER_SIZE = 32 * 1024
CURRENT_THREADS = 0 CURRENT_THREADS = 0
MAX_THREADS = 50 MAX_THREADS = 50
@ -86,66 +87,111 @@ class Server:
self.thread_check() self.thread_check()
CURRENT_THREADS += 1 CURRENT_THREADS += 1
thread = Thread(target = self.handle_connection, args = (conn, client_addr, )) client_connection = ClientConnection(conn, client_addr)
thread = Thread(target = client_connection.handle_connection)
CURRENT_THREADS -= 1 CURRENT_THREADS -= 1
thread.start() thread.start()
def __del__(self): def __del__(self):
self.sock.close() self.sock.close()
def handle_connection(self, client_socket, client_address): class ClientConnection:
request = client_socket.recv(BUFFER_SIZE) def __init__(self, client_socket, client_address):
logger = Logger.instance() self.client_socket = client_socket
self.client_addr = client_address[0]
self.client_port = client_address[1]
if len(request) == 0: self.server_socket = None
client_socket.close()
return
try: def check_monitorando(self, request):
raw_request = request.decode() if b"monitorando" in request:
except UnicodeDecodeError: body = open('forbidden.html', 'r').read()
client_socket.close() self.client_socket.sendall(b"HTTP/1.1 403 Forbidden\r\n\r\n")
return self.client_socket.sendall(body.encode())
self.client_socket.close()
return True
if "CONNECT" in raw_request: return False
client_socket.sendall(b"HTTP/1.1 200 Connection Established\r\n\r\n")
request = client_socket.recv(BUFFER_SIZE)
request_url = raw_request.split(' ')[1] def get_host_port(self, request):
request_url = request.split(' ')[1]
request_host = "" request_host = ""
request_port = 443 if 'https' in request_url else 80 request_port = 443 if request_url.startswith('https') else 80
# Remove protocolo do request
if request_url.startswith('http'): if request_url.startswith('http'):
request_host = request_url.split('/')[2] request_host = request_url.split('//')[1]
else: else:
request_host = request_url.split('/')[0] request_host = request_url.split('/')[0]
# Remove 'www' do request
if request_host.startswith('www'): if request_host.startswith('www'):
request_host = request_host[4:] request_host = request_host.split('www.')[1]
# Remove porta do request e verifica se é um host com porta
if ':' in request_host: if ':' in request_host:
request_port = int(request_host.split(':')[1]) request_port = int(request_host.split(':')[1])
request_host = request_host.split(':')[0] request_host = request_host.split(':')[0]
if "monitorando" in request_url.lower(): return request_host, request_port
body = open('forbidden.html', 'r').read()
client_socket.sendall(b"HTTP/1.1 403 Forbidden\r\n\r\n") def handle_connection(self):
client_socket.sendall(body.encode()) request = self.client_socket.recv(BUFFER_SIZE)
client_socket.close() logger = Logger.instance()
logger.info(f"REQUEST [{client_address[0]}:{client_address[1]}] to [{request_host}:{request_port}] - 'Monitorando' - 403 Forbidden")
# Verifica se o request é vazio, se for, fecha o socket
if len(request) == 0:
self.client_socket.close()
return return
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try:
server_socket.connect((request_host, request_port)) # Tenta decodificar o request, se não conseguir, fecha o socket
server_socket.send(request) raw_request = request.decode()
except UnicodeDecodeError:
self.client_socket.close()
return
# Verifica se é um request CONNECT
# Se for, retorna 200 Connection Established, e espera o request novamente
if "CONNECT" in raw_request:
self.client_socket.sendall(b"HTTP/1.1 200 Connection Established\r\n\r\n")
request = self.client_socket.recv(BUFFER_SIZE)
# Retorna host e porta do request
request_host, request_port = self.get_host_port(raw_request)
# Verifica se 'monitorando' está no request
if self.check_monitorando(request):
logger.info(f"REQUEST [{self.client_addr}:{self.client_port}] to [{request_host}:{request_port}] - 'Monitorando' - 403 Forbidden")
return
# Cria socket para o servidor
self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server_socket.connect((request_host, request_port))
self.server_socket.send(request)
'''
Enquanto houver dados para serem lidos, do socket do cliente e envia para o servidor
e do socket do servidor e envia para o cliente
Se não houver dados para serem lidos, fecha os sockets
Exemplo de fluxo:
proxy -> server
server -> proxy
proxy -> client
proxy -> client
proxy -> client
server -> proxy
'''
data = None
while True: while True:
triple = select.select([client_socket, server_socket], [], [], 10)[0] triple = select.select([self.client_socket, self.server_socket], [], [], CONNECTION_TIMEOUT_SEC)[0]
if not len(triple): if not len(triple):
break break
try: try:
if server_socket in triple: if self.server_socket in triple:
data = server_socket.recv(BUFFER_SIZE) data = self.server_socket.recv(BUFFER_SIZE)
if not data: if not data:
break break
@ -153,23 +199,23 @@ class Server:
status_code = data.decode().split('\r\n')[0].split(' ')[1:] status_code = data.decode().split('\r\n')[0].split(' ')[1:]
status_code = ' '.join(status_code) status_code = ' '.join(status_code)
if is_valid_status_code(status_code): if is_valid_status_code(status_code):
logger.info(f"REQUEST [{client_address[0]}:{client_address[1]}] to [{request_host}:{request_port}] - {status_code}") logger.info(f"REQUEST [{self.client_addr}:{self.client_port}] to [{request_host}:{request_port}] - {status_code}")
except UnicodeDecodeError: except UnicodeDecodeError:
logger.info(f"REQUEST [{client_address[0]}:{client_address[1]}] to [{request_host}:{request_port}]") logger.info(f"REQUEST [{self.client_addr}:{self.client_port}] to [{request_host}:{request_port}]")
pass pass
client_socket.send(data) self.client_socket.send(data)
if client_socket in triple: if self.client_socket in triple:
data = client_socket.recv(BUFFER_SIZE) data = self.client_socket.recv(BUFFER_SIZE)
if not data: if not data:
break break
server_socket.send(data) self.server_socket.send(data)
except ConnectionAbortedError: except ConnectionAbortedError:
break break
server_socket.close() self.server_socket.close()
client_socket.close() self.client_socket.close()
def is_valid_status_code(status_code: str): def is_valid_status_code(status_code: str):
valid_starts = [str(i) for i in range(5)] valid_starts = [str(i) for i in range(5)]

View File

@ -1 +1 @@
02c9fe27bfc80f2a71e6a7b618c379fcd60b18c4e8cc646d081b1eebb7682bba 7fb3b97ffb3311c5fec9d1f37b7c9726a2f5283947eb026676723ac49a527e44