# This module is based on the excellent work by Adam Bartoš who # provided a lot of what went into the implementation here in # the discussion to issue1602 in the Python bug tracker. # # There are some general differences in regards to how this works # compared to the original patches as we do not need to patch # the entire interpreter but just work in our little world of # echo and prompt. import io import sys import time import typing as t from ctypes import byref from ctypes import c_char from ctypes import c_char_p from ctypes import c_int from ctypes import c_ssize_t from ctypes import c_ulong from ctypes import c_void_p from ctypes import POINTER from ctypes import py_object from ctypes import Structure from ctypes.wintypes import DWORD from ctypes.wintypes import HANDLE from ctypes.wintypes import LPCWSTR from ctypes.wintypes import LPWSTR from ._compat import _NonClosingTextIOWrapper assert sys.platform == "win32" import msvcrt # noqa: E402 from ctypes import windll # noqa: E402 from ctypes import WINFUNCTYPE # noqa: E402 c_ssize_p = POINTER(c_ssize_t) kernel32 = windll.kernel32 GetStdHandle = kernel32.GetStdHandle ReadConsoleW = kernel32.ReadConsoleW WriteConsoleW = kernel32.WriteConsoleW GetConsoleMode = kernel32.GetConsoleMode GetLastError = kernel32.GetLastError GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32)) CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))( ("CommandLineToArgvW", windll.shell32) ) LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32)) STDIN_HANDLE = GetStdHandle(-10) STDOUT_HANDLE = GetStdHandle(-11) STDERR_HANDLE = GetStdHandle(-12) PyBUF_SIMPLE = 0 PyBUF_WRITABLE = 1 ERROR_SUCCESS = 0 ERROR_NOT_ENOUGH_MEMORY = 8 ERROR_OPERATION_ABORTED = 995 STDIN_FILENO = 0 STDOUT_FILENO = 1 STDERR_FILENO = 2 EOF = b"\x1a" MAX_BYTES_WRITTEN = 32767 try: from ctypes import pythonapi except ImportError: # On PyPy we cannot get buffers so our ability to operate here is # severely limited. get_buffer = None else: class Py_buffer(Structure): _fields_ = [ ("buf", c_void_p), ("obj", py_object), ("len", c_ssize_t), ("itemsize", c_ssize_t), ("readonly", c_int), ("ndim", c_int), ("format", c_char_p), ("shape", c_ssize_p), ("strides", c_ssize_p), ("suboffsets", c_ssize_p), ("internal", c_void_p), ] PyObject_GetBuffer = pythonapi.PyObject_GetBuffer PyBuffer_Release = pythonapi.PyBuffer_Release def get_buffer(obj, writable=False): buf = Py_buffer() flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE PyObject_GetBuffer(py_object(obj), byref(buf), flags) try: buffer_type = c_char * buf.len return buffer_type.from_address(buf.buf) finally: PyBuffer_Release(byref(buf)) class _WindowsConsoleRawIOBase(io.RawIOBase): def __init__(self, handle): self.handle = handle def isatty(self): super().isatty() return True class _WindowsConsoleReader(_WindowsConsoleRawIOBase): def readable(self): return True def readinto(self, b): bytes_to_be_read = len(b) if not bytes_to_be_read: return 0 elif bytes_to_be_read % 2: raise ValueError( "cannot read odd number of bytes from UTF-16-LE encoded console" ) buffer = get_buffer(b, writable=True) code_units_to_be_read = bytes_to_be_read // 2 code_units_read = c_ulong() rv = ReadConsoleW( HANDLE(self.handle), buffer, code_units_to_be_read, byref(code_units_read), None, ) if GetLastError() == ERROR_OPERATION_ABORTED: # wait for KeyboardInterrupt time.sleep(0.1) if not rv: raise OSError(f"Windows error: {GetLastError()}") if buffer[0] == EOF: return 0 return 2 * code_units_read.value class _WindowsConsoleWriter(_WindowsConsoleRawIOBase): def writable(self): return True @staticmethod def _get_error_message(errno): if errno == ERROR_SUCCESS: return "ERROR_SUCCESS" elif errno == ERROR_NOT_ENOUGH_MEMORY: return "ERROR_NOT_ENOUGH_MEMORY" return f"Windows error {errno}" def write(self, b): bytes_to_be_written = len(b) buf = get_buffer(b) code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2 code_units_written = c_ulong() WriteConsoleW( HANDLE(self.handle), buf, code_units_to_be_written, byref(code_units_written), None, ) bytes_written = 2 * code_units_written.value if bytes_written == 0 and bytes_to_be_written > 0: raise OSError(self._get_error_message(GetLastError())) return bytes_written class ConsoleStream: def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None: self._text_stream = text_stream self.buffer = byte_stream @property def name(self) -> str: return self.buffer.name def write(self, x: t.AnyStr) -> int: if isinstance(x, str): return self._text_stream.write(x) try: self.flush() except Exception: pass return self.buffer.write(x) def writelines(self, lines: t.Iterable[t.AnyStr]) -> None: for line in lines: self.write(line) def __getattr__(self, name: str) -> t.Any: return getattr(self._text_stream, name) def isatty(self) -> bool: return self.buffer.isatty() def __repr__(self): return f"" def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO: text_stream = _NonClosingTextIOWrapper( io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)), "utf-16-le", "strict", line_buffering=True, ) return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream)) def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO: text_stream = _NonClosingTextIOWrapper( io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)), "utf-16-le", "strict", line_buffering=True, ) return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream)) def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO: text_stream = _NonClosingTextIOWrapper( io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)), "utf-16-le", "strict", line_buffering=True, ) return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream)) _stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = { 0: _get_text_stdin, 1: _get_text_stdout, 2: _get_text_stderr, } def _is_console(f: t.TextIO) -> bool: if not hasattr(f, "fileno"): return False try: fileno = f.fileno() except (OSError, io.UnsupportedOperation): return False handle = msvcrt.get_osfhandle(fileno) return bool(GetConsoleMode(handle, byref(DWORD()))) def _get_windows_console_stream( f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str] ) -> t.Optional[t.TextIO]: if ( get_buffer is not None and encoding in {"utf-16-le", None} and errors in {"strict", None} and _is_console(f) ): func = _stream_factories.get(f.fileno()) if func is not None: b = getattr(f, "buffer", None) if b is None: return None return func(b)