# Copyright (c) 2012, 2023, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

# mypy: disable-error-code="attr-defined"

"""Module implementing low-level socket communication with MySQL servers.
"""

import os
import socket
import struct
import warnings
import zlib

from abc import ABC, abstractmethod
from collections import deque

try:
    import ssl

    # Maps a TLS version name to the `ssl` protocol constant used to build the
    # SSL context for that version.
    TLS_VERSIONS = {
        "TLSv1": ssl.PROTOCOL_TLSv1,
        "TLSv1.1": ssl.PROTOCOL_TLSv1_1,
        "TLSv1.2": ssl.PROTOCOL_TLSv1_2,
    }
    # TLSv1.3 included in PROTOCOL_TLS, but PROTOCOL_TLS is not included on 3.4
    TLS_VERSIONS["TLSv1.3"] = (
        ssl.PROTOCOL_TLS
        if hasattr(ssl, "PROTOCOL_TLS")
        else ssl.PROTOCOL_SSLv23  # Alias of PROTOCOL_TLS
    )
    TLS_V1_3_SUPPORTED = hasattr(ssl, "HAS_TLSv1_3") and ssl.HAS_TLSv1_3
except ImportError:
    # If import fails, we don't have SSL support. Note that neither `ssl` nor
    # `TLS_VERSIONS` is defined in this case; only the flag below is.
    TLS_V1_3_SUPPORTED = False

from typing import Any, Deque, List, Optional, Tuple, Union

from .errors import InterfaceError, NotSupportedError, OperationalError
from .types import StrOrBytesPath

# Payloads of at most this many bytes are sent without being compressed.
MIN_COMPRESS_LENGTH: int = 50
# Largest payload a single packet may carry (the length field is ff ff ff).
MAX_PAYLOAD_LENGTH: int = 2**24 - 1
# Header size, in bytes, of a plain packet.
PACKET_HEADER_LENGTH: int = 4
# Header size, in bytes, of a compressed packet.
COMPRESSED_PACKET_HEADER_LENGTH: int = 7


def _strioerror(err: IOError) -> str:
    """Reformat the IOError error message.

    Args:
        err: The I/O error to describe.

    Returns:
        `str(err)` when `err.errno` is falsy (unset or 0), otherwise the string
        "<errno> <strerror>".
    """
    return str(err) if not err.errno else f"{err.errno} {err.strerror}"


class NetworkBroker(ABC):
    """Broker class interface.

    The network object is a broker used as a delegate by a socket object. Whenever the
    socket wants to deliver or get packets to or from the MySQL server it needs to rely
    on its network broker (netbroker).

    The netbroker sends `payloads` and receives `packets`.

    A packet is a bytes sequence, it has a header and body (referred to as payload).
    The first `PACKET_HEADER_LENGTH` or `COMPRESSED_PACKET_HEADER_LENGTH`
    (as appropriate) bytes correspond to the `header`, the remaining ones represent the
    `payload`.

    The maximum payload length allowed to be sent per packet to the server is
    `MAX_PAYLOAD_LENGTH`. When `send` is called with a payload whose length is greater
    than `MAX_PAYLOAD_LENGTH` the netbroker breaks it down into packets, so the caller
    of `send` can provide payloads of arbitrary length.

    Finally, data received by the netbroker comes directly from the server, expect to
    get a packet for each call to `recv`. The received packet contains a header and
    payload, the latter respecting `MAX_PAYLOAD_LENGTH`.
    """

    @abstractmethod
    def send(
        self,
        sock: socket.socket,
        address: str,
        payload: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
    ) -> None:
        """Send `payload` to the MySQL server.

        If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
        broken down into packets.

        Args:
            sock: Object holding the socket connection.
            address: Socket's location.
            payload: Packet's body to send.
            packet_number: Sequence id (packet ID) to attach to the header when sending
                           plain packets.
            compressed_packet_number: Same as `packet_number` but used when sending
                                      compressed packets.

        Raises:
            :class:`OperationalError`: If something goes wrong while sending packets to
                                       the MySQL server.
        """

    @abstractmethod
    def recv(self, sock: socket.socket, address: str) -> bytearray:
        """Get the next available packet from the MySQL server.

        Args:
            sock: Object holding the socket connection.
            address: Socket's location.

        Returns:
            packet: A packet from the MySQL server.

        Raises:
            :class:`OperationalError`: If something goes wrong while receiving packets
                                       from the MySQL server.
            :class:`InterfaceError`: If something goes wrong while receiving packets
                                     from the MySQL server.
        """


class NetworkBrokerPlain(NetworkBroker):
    """Broker class for MySQL socket communication.

    Implements the `NetworkBroker` interface and keeps track of the packet
    sequence id in `_pktnr`.
    """

    def __init__(self) -> None:
        # Starts at -1 so that the first call to `_set_next_pktnr` yields 0.
        self._pktnr: int = -1  # packet number

    def _set_next_pktnr(self) -> None:
        """Increment packet id, wrapping around to 0 after 255."""
        self._pktnr = (self._pktnr + 1) % 256

    def _send_pkt(self, sock: socket.socket, address: str, pkt: bytes) -> None:
        """Write packet to the comm channel.

        Args:
            sock: Object holding the socket connection.
            address: Socket's location; only used to build the error message.
            pkt: Complete packet to write with `sock.sendall`.

        Raises:
            :class:`OperationalError`: With errno 2055 if `sendall` raises an
                                       `IOError`, or with errno 2006 if it raises
                                       an `AttributeError`.
        """
        try:
            sock.sendall(pkt)
        except IOError as err:
            raise OperationalError(
                errno=2055, values=(address, _strioerror(err))
            ) from err
        except AttributeError as err:
            # presumably `sock` is None because no connection is open - TODO confirm
            raise OperationalError(errno=2006) from err

    def _recv_chunk(self, sock: socket.socket, size: int = 0) -> bytearray:
        """Read `size` bytes from the comm channel.

        Args:
            sock: Object holding the socket connection.
            size: Exact number of bytes to read.

        Returns:
            A bytearray of length `size` filled with the received bytes.

        Raises:
            :class:`InterfaceError`: With errno 2013 if `recv_into` returns 0
                                     before `size` bytes were received.
        """
        pkt = bytearray(size)
        # Receive straight into the preallocated buffer through a view, so no
        # intermediate chunks have to be copied and joined.
        pkt_view = memoryview(pkt)
        while size:
            read = sock.recv_into(pkt_view, size)
            if read == 0 and size > 0:
                raise InterfaceError(errno=2013)
            # Slide the view past the bytes just received.
            pkt_view = pkt_view[read:]
            size -= read
        return pkt

    def send(
        self,
        sock: socket.socket,
        address: str,
        payload: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
    ) -> None:
        """Send payload to the MySQL server.

        If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
        broken down into packets.
        """
        # Use the caller-supplied sequence id when given, else advance our own.
        if packet_number is None:
            self._set_next_pktnr()
        else:
            self._pktnr = packet_number

        # If the payload is larger than or equal to MAX_PAYLOAD_LENGTH
        # the length is set to 2^24 - 1 (ff ff ff) and additional
        # packets are sent with the rest of the payload until the
        # payload of a packet is less than MAX_PAYLOAD_LENGTH.
if len(payload) >= MAX_PAYLOAD_LENGTH: offset = 0 for _ in range(len(payload) // MAX_PAYLOAD_LENGTH): # payload_len, sequence_id, payload self._send_pkt( sock, address, b"\xff\xff\xff" + struct.pack(" bytearray: """Receive `one` packet from the MySQL server.""" try: # Read the header of the MySQL packet header = self._recv_chunk(sock, size=PACKET_HEADER_LENGTH) # Pull the payload length and sequence id payload_len, self._pktnr = ( struct.unpack(" None: super().__init__() self._compressed_pktnr = -1 self._queue_read: Deque[bytearray] = deque() @staticmethod def _prepare_packets(payload: bytes, pktnr: int) -> List[bytes]: """Prepare a payload for sending to the MySQL server.""" pkts = [] # If the payload is larger than or equal to MAX_PAYLOAD_LENGTH # the length is set to 2^24 - 1 (ff ff ff) and additional # packets are sent with the rest of the payload until the # payload of a packet is less than MAX_PAYLOAD_LENGTH. if len(payload) >= MAX_PAYLOAD_LENGTH: offset = 0 for _ in range(len(payload) // MAX_PAYLOAD_LENGTH): # payload length + sequence id + payload pkts.append( b"\xff\xff\xff" + struct.pack(" None: """Increment packet id.""" self._compressed_pktnr = (self._compressed_pktnr + 1) % 256 def _send_pkt(self, sock: socket.socket, address: str, pkt: bytes) -> None: """Compress packet and write it to the comm channel.""" compressed_pkt = zlib.compress(pkt) pkt = ( struct.pack(" None: """Send `payload` as compressed packets to the MySQL server. If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is broken down into packets. 
""" # get next packet numbers if packet_number is None: self._set_next_pktnr() else: self._pktnr = packet_number if compressed_packet_number is None: self._set_next_compressed_pktnr() else: self._compressed_pktnr = compressed_packet_number payload_prep = bytearray(b"").join(self._prepare_packets(payload, self._pktnr)) if len(payload) >= MAX_PAYLOAD_LENGTH - PACKET_HEADER_LENGTH: # sending a MySQL payload of the size greater or equal to 2^24 - 5 # via compression leads to at least one extra compressed packet # WHY? let's say len(payload) is MAX_PAYLOAD_LENGTH - 3; when preparing # the payload, a header of size PACKET_HEADER_LENGTH is pre-appended # to the payload. This means that len(payload_prep) is # MAX_PAYLOAD_LENGTH - 3 + PACKET_HEADER_LENGTH = MAX_PAYLOAD_LENGTH + 1 # surpassing the maximum allowed payload size per packet. offset = 0 # send several MySQL packets for _ in range(len(payload_prep) // MAX_PAYLOAD_LENGTH): self._send_pkt( sock, address, payload_prep[offset : offset + MAX_PAYLOAD_LENGTH] ) self._set_next_compressed_pktnr() offset += MAX_PAYLOAD_LENGTH self._send_pkt(sock, address, payload_prep[offset:]) else: # send one MySQL packet # For small packets it may be too costly to compress the packet. # Usually payloads less than 50 bytes (MIN_COMPRESS_LENGTH) # aren't compressed (see MySQL source code Documentation). if len(payload) > MIN_COMPRESS_LENGTH: # perform compression self._send_pkt(sock, address, payload_prep) else: # skip compression super()._send_pkt( sock, address, struct.pack(" None: """Handle reading of a compressed packet.""" # compressed_pll stands for compressed payload length. # Recalling that if uncompressed payload length == 0, the packet # comes in uncompressed, so no decompression is needed. 
compressed_pkt = super()._recv_chunk(sock, size=compressed_pll) pkt = ( compressed_pkt if uncompressed_pll == 0 else bytearray(zlib.decompress(compressed_pkt)) ) offset = 0 while offset < len(pkt): # pll stands for payload length pll = struct.unpack( " len(pkt) - offset: # More bytes need to be consumed # Read the header of the next MySQL packet header = super()._recv_chunk(sock, size=COMPRESSED_PACKET_HEADER_LENGTH) # compressed payload length, sequence id, uncompressed payload length ( compressed_pll, self._compressed_pktnr, uncompressed_pll, ) = ( struct.unpack(" bytearray: """Receive `one` or `several` packets from the MySQL server, enqueue them, and return the packet at the head. """ if not self._queue_read: try: # Read the header of the next MySQL packet header = super()._recv_chunk(sock, size=COMPRESSED_PACKET_HEADER_LENGTH) # compressed payload length, sequence id, uncompressed payload length ( compressed_pll, self._compressed_pktnr, uncompressed_pll, ) = ( struct.unpack(" None: """Network layer where transactions are made with plain (uncompressed) packets is enabled by default. 
""" # holds the socket connection self.sock: Optional[socket.socket] = None self._connection_timeout: Optional[int] = None self.server_host: Optional[str] = None self._netbroker: NetworkBroker = NetworkBrokerPlain() def switch_to_compressed_mode(self) -> None: """Enable network layer where transactions are made with compressed packets.""" self._netbroker = NetworkBrokerCompressed() def shutdown(self) -> None: """Shut down the socket before closing it.""" try: self.sock.shutdown(socket.SHUT_RDWR) self.sock.close() except (AttributeError, OSError): pass def close_connection(self) -> None: """Close the socket.""" try: self.sock.close() except (AttributeError, OSError): pass def __del__(self) -> None: self.shutdown() def set_connection_timeout(self, timeout: Optional[int]) -> None: """Set the connection timeout.""" self._connection_timeout = timeout if self.sock: self.sock.settimeout(timeout) def switch_to_ssl( self, ca: StrOrBytesPath, cert: StrOrBytesPath, key: StrOrBytesPath, verify_cert: bool = False, verify_identity: bool = False, cipher_suites: Optional[str] = None, tls_versions: Optional[List[str]] = None, ) -> None: """Switch the socket to use SSL""" if not self.sock: raise InterfaceError(errno=2048) try: if verify_cert: cert_reqs = ssl.CERT_REQUIRED elif verify_identity: cert_reqs = ssl.CERT_OPTIONAL else: cert_reqs = ssl.CERT_NONE if tls_versions is None or not tls_versions: context = ssl.create_default_context() if not verify_identity: context.check_hostname = False else: tls_versions.sort(reverse=True) tls_version = tls_versions[0] if ( not TLS_V1_3_SUPPORTED and tls_version == "TLSv1.3" and len(tls_versions) > 1 ): tls_version = tls_versions[1] ssl_protocol = TLS_VERSIONS[tls_version] context = ssl.SSLContext(ssl_protocol) if tls_version == "TLSv1.3": if "TLSv1.2" not in tls_versions: context.options |= ssl.OP_NO_TLSv1_2 if "TLSv1.1" not in tls_versions: context.options |= ssl.OP_NO_TLSv1_1 if "TLSv1" not in tls_versions: context.options |= ssl.OP_NO_TLSv1 
context.check_hostname = False context.verify_mode = cert_reqs context.load_default_certs() if ca: try: context.load_verify_locations(ca) except (IOError, ssl.SSLError) as err: self.sock.close() raise InterfaceError(f"Invalid CA Certificate: {err}") from err if cert: try: context.load_cert_chain(cert, key) except (IOError, ssl.SSLError) as err: self.sock.close() raise InterfaceError(f"Invalid Certificate/Key: {err}") from err if cipher_suites: context.set_ciphers(cipher_suites) if hasattr(self, "server_host"): self.sock = context.wrap_socket( self.sock, server_hostname=self.server_host ) else: self.sock = context.wrap_socket(self.sock) if verify_identity: context.check_hostname = True hostnames: List[str] = [self.server_host] if self.server_host else [] if os.name == "nt" and self.server_host == "localhost": hostnames = ["localhost", "127.0.0.1"] aliases = socket.gethostbyaddr(self.server_host) hostnames.extend([aliases[0]] + aliases[1]) match_found = False errs = [] for hostname in hostnames: try: # Deprecated in Python 3.7 without a replacement and # should be removed in the future, since OpenSSL now # performs hostname matching # pylint: disable=deprecated-method ssl.match_hostname(self.sock.getpeercert(), hostname) # pylint: enable=deprecated-method except ssl.CertificateError as err: errs.append(str(err)) else: match_found = True break if not match_found: self.sock.close() raise InterfaceError( f"Unable to verify server identity: {', '.join(errs)}" ) except NameError as err: raise NotSupportedError("Python installation has no SSL support") from err except (ssl.SSLError, IOError) as err: raise InterfaceError( errno=2055, values=(self.address, _strioerror(err)) ) from err except ssl.CertificateError as err: raise InterfaceError(str(err)) from err except NotImplementedError as err: raise InterfaceError(str(err)) from err def send( self, payload: bytes, packet_number: Optional[int] = None, compressed_packet_number: Optional[int] = None, ) -> None: """Send 
`payload` to the MySQL server.""" return self._netbroker.send( self.sock, self.address, payload, packet_number=packet_number, compressed_packet_number=compressed_packet_number, ) def recv(self) -> bytearray: """Get packet from the MySQL server comm channel.""" return self._netbroker.recv(self.sock, self.address) @abstractmethod def open_connection(self) -> None: """Open the socket.""" @property @abstractmethod def address(self) -> str: """Get the location of the socket.""" class MySQLUnixSocket(MySQLSocket): """MySQL socket class using UNIX sockets. Opens a connection through the UNIX socket of the MySQL Server. """ def __init__(self, unix_socket: str = "/tmp/mysql.sock") -> None: super().__init__() self.unix_socket: str = unix_socket self._address: str = unix_socket @property def address(self) -> str: return self._address def open_connection(self) -> None: try: self.sock = socket.socket( socket.AF_UNIX, socket.SOCK_STREAM # pylint: disable=no-member ) self.sock.settimeout(self._connection_timeout) self.sock.connect(self.unix_socket) except IOError as err: raise InterfaceError( errno=2002, values=(self.address, _strioerror(err)) ) from err except Exception as err: raise InterfaceError(str(err)) from err def switch_to_ssl( self, *args: Any, **kwargs: Any # pylint: disable=unused-argument ) -> None: """Switch the socket to use SSL.""" warnings.warn( "SSL is disabled when using unix socket connections", Warning, ) class MySQLTCPSocket(MySQLSocket): """MySQL socket class using TCP/IP. Opens a TCP/IP connection to the MySQL Server. 
""" def __init__( self, host: str = "127.0.0.1", port: int = 3306, force_ipv6: bool = False, ) -> None: super().__init__() self.server_host: str = host self.server_port: int = port self.force_ipv6: bool = force_ipv6 self._family: int = 0 self._address: str = f"{host}:{port}" @property def address(self) -> str: return self._address def open_connection(self) -> None: """Open the TCP/IP connection to the MySQL server.""" # pylint: disable=no-member # Get address information addrinfo: Union[ Tuple[None, None, None, None, None], Tuple[ socket.AddressFamily, socket.SocketKind, int, str, Union[Tuple[str, int], Tuple[str, int, int, int]], ], ] = (None, None, None, None, None) try: addrinfos = socket.getaddrinfo( self.server_host, self.server_port, 0, socket.SOCK_STREAM, socket.SOL_TCP, ) # If multiple results we favor IPv4, unless IPv6 was forced. for info in addrinfos: if self.force_ipv6 and info[0] == socket.AF_INET6: addrinfo = info break if info[0] == socket.AF_INET: addrinfo = info break if self.force_ipv6 and addrinfo[0] is None: raise InterfaceError(f"No IPv6 address found for {self.server_host}") if addrinfo[0] is None: addrinfo = addrinfos[0] except IOError as err: raise InterfaceError( errno=2003, values=(self.address, _strioerror(err)) ) from err (self._family, socktype, proto, _, sockaddr) = addrinfo # Instanciate the socket and connect try: self.sock = socket.socket(self._family, socktype, proto) self.sock.settimeout(self._connection_timeout) self.sock.connect(sockaddr) except IOError as err: raise InterfaceError( errno=2003, values=( self.server_host, self.server_port, _strioerror(err), ), ) from err except Exception as err: raise OperationalError(str(err)) from err