Skip to content

orchestr8.oauth_flow

Base class for OAuth flow implementation.

import requests

from orchestr8.oauth_flow import OAuthFlow


class MyOAuthFlow(OAuthFlow):
    """Example OAuth flow for a hypothetical `service.com` provider."""

    @property
    def auth_url(self) -> str:
        return (
            "https://service.com/api/oauth2/authorize"
            f"?client_id={self.client_id}&redirect_uri={self.quoted_redirect_url}"
            f"&response_type=code&scope={self.user_scopes}"
        )

    def _generate_access_token(self, code: str) -> str:
        response = requests.post(
            "https://service.com/api/oauth2/token",
            data={
                'grant_type': 'authorization_code',
                'code': code,
                'redirect_uri': self.redirect_url,
            },
            auth=(self.client_id, self.client_secret),
            # requests waits indefinitely unless a timeout is given.
            timeout=30,
        )

        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; calling `response.json()`
            # here could replace the real failure with a decode error.
            raise RuntimeError(f"Failed to obtain access token: {response.text}")

        return response.json()['access_token']


my_oauth_flow = MyOAuthFlow(
    client_id="<client-id>",
    client_secret="<client-secret>",
    user_scopes="identify messages.read",
)
access_token = my_oauth_flow.authorize(timeout=30)
Source code in orchestr8/oauth_flow/__init__.py
class OAuthFlow(Logger):
    """Base class for OAuth flow implementation.

    Subclasses must provide an `auth_url` property and a `_generate_access_token`
    method. They may define these directly or inherit them from an intermediate
    subclass/mixin. This is checked when the subclass is created.

    ```python
    import requests

    from orchestr8.oauth_flow import OAuthFlow


    class MyOAuthFlow(OAuthFlow):

        @property
        def auth_url(self) -> str:
            return (
                "https://service.com/api/oauth2/authorize"
                f"?client_id={self.client_id}&redirect_uri={self.quoted_redirect_url}"
                f"&response_type=code&scope={self.user_scopes}"
            )

        def _generate_access_token(self, code: str) -> str:
            response = requests.post(
                "https://service.com/api/oauth2/token",
                data={
                    'grant_type': 'authorization_code',
                    'code': code,
                    'redirect_uri': self.redirect_url,
                },
                auth=(self.client_id, self.client_secret),
                timeout=30,
            )

            if response.status_code != 200:
                raise RuntimeError(f"Failed to obtain access token: {response.text}")

            return response.json()['access_token']

    my_oauth_flow = MyOAuthFlow(
        client_id="<client-id>",
        client_secret="<client-secret>",
        user_scopes="identify messages.read",
    )
    my_oauth_flow.authorize(timeout=30)
    ```
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Validate that a subclass provides the required overrides.

        Lookup walks the subclass's MRO up to, but excluding, `OAuthFlow`. An
        implementation inherited from an intermediate base class therefore counts,
        while the `OAuthFlow` stubs themselves do not.

        Raises:
            TypeError: If `auth_url` is not a property, or `_generate_access_token`
                is not callable.
        """
        super().__init_subclass__(**kwargs)

        def _override(name: str) -> Any:
            # The first definition found before OAuthFlow in the MRO is the one that
            # attribute lookup on instances will actually resolve to.
            for klass in cls.__mro__:
                if klass is OAuthFlow:
                    break
                if name in vars(klass):
                    return vars(klass)[name]
            return None

        if not (attr := _override("auth_url")) or not isinstance(attr, property):
            raise TypeError(f"{cls.__name__} must define 'auth_url' property.")

        if not (attr := _override("_generate_access_token")) or not callable(attr):
            raise TypeError(f"{cls.__name__} must define '_generate_access_token' method.")

    def __new__(cls, *args: Any, **kwargs: Any) -> OAuthFlow:
        """Refuse direct instantiation of the base class.

        Raises:
            TypeError: If `cls` is `OAuthFlow` itself.
        """
        if cls is OAuthFlow:
            raise TypeError("OAuthFlow cannot be instantiated directly")
        return super().__new__(cls)

    def __init__(
        self,
        *,
        client_id: str,
        client_secret: str,
        user_scopes: str | None = None,
        redirect_port: int | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            client_id: Client ID
            client_secret: Client secret
            user_scopes: User scopes. Stored URL-encoded with `quote_plus`
                (an empty string when None).
            redirect_port: Port to redirect on. Default: `41539`
            kwargs: Additional keyword arguments, kept on `self.kwargs`.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.user_scopes = quote_plus(user_scopes or "")
        self.kwargs = kwargs

        self._redirect_server = RedirectServer(port=redirect_port)

    @property
    def redirect_url(self) -> str:
        """URL of locally hosted server. Starts the server if it is not running yet."""
        if not self._redirect_server.is_running:
            self._redirect_server.start()
        return self._redirect_server.url

    @property
    def quoted_redirect_url(self) -> str:
        """Redirect URL, URL-encoded with `quote_plus` for use in a query string."""
        return quote_plus(self.redirect_url)

    @property
    def auth_url(self) -> str:
        """Formatted authorization URL for getting code."""
        raise NotImplementedError("Property should be implemented by inherited class.")

    def _get_auth_code(self, *, timeout: int | None = None) -> str:
        """
        Get authorization code. If timeout is None, it blocks until a request is intercepted.

        Args:
            timeout: seconds to wait for. Default: `None`

        Returns:
            Authorization code

        Raises:
            TimeoutError: If the redirect server intercepted no request.
            ConnectionAbortedError: If the intercepted URL has no `code` query parameter.
        """

        self.logger.info(
            "<w>Add this URL to your application's redirect settings:</w> <u><le>{}</le></u>", self.redirect_url
        )
        self.logger.info("<w>Click this URL to authorize:</w> <u><le>{}</le></u>", self.auth_url)

        if not (req_url := self._redirect_server.intercept(timeout=timeout)):
            raise TimeoutError(
                "No authorization request was intercepted. The request may have timed out or was never received."
            )

        if codes := parse_qs(urlparse(req_url).query).get("code"):
            return codes[0]

        raise ConnectionAbortedError("Authorization failed or was denied. Please try again.")

    def _generate_access_token(self, code: str) -> str:
        """
        Generate access token

        Args:
            code: Authorization code

        Returns:
            Access token
        """
        raise NotImplementedError("Method should be implemented by inherited class.")

    def authorize(self, *, timeout: int | None = None) -> str:
        """
        Authorize and return access token.

        Args:
            timeout: seconds to wait for. Default: `None`

        Returns:
            Access token
        """
        self.logger.info("Starting authorization process")
        code = self._get_auth_code(timeout=timeout)
        self.logger.success("Authorization code received")
        token = self._generate_access_token(code)
        self.logger.success("Access token generated successfully")
        return token

    def __del__(self) -> None:
        # `__init__` may have raised before `_redirect_server` was assigned (for
        # example, a missing required keyword argument). Don't turn that into an
        # AttributeError during finalization.
        server = getattr(self, "_redirect_server", None)
        if server is not None and server.is_running:
            server.stop()

auth_url: str property

Formatted authorization URL for getting code.

quoted_redirect_url: str property

Redirect URL, URL-encoded with `quote_plus` for use as a query-string value.

redirect_url: str property

URL of the locally hosted redirect server. Accessing it starts the server if it is not already running.

authorize

Authorize and return access token.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `timeout` | `int \| None` | Seconds to wait for. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `str` | Access token |

Source code in orchestr8/oauth_flow/__init__.py
def authorize(self, *, timeout: int | None = None) -> str:
    """
    Authorize and return access token.

    Args:
        timeout: seconds to wait for. Default: `None`

    Returns:
        Access token
    """
    self.logger.info("Starting authorization process")
    code = self._get_auth_code(timeout=timeout)
    self.logger.success("Authorization code received")
    token = self._generate_access_token(code)
    self.logger.success("Access token generated successfully")
    return token

RedirectServer

Host a redirect server locally.

server = RedirectServer(port=8080)
with server:
    print(server.url)  # https://localhost:8080/
    print(server.is_running)  # True
    print(server.intercept(timeout=10))  # https://localhost:8080/?code=123456
Source code in orchestr8/oauth_flow/redirect_server.py
class RedirectServer(Logger):
    """
    Host a redirect server locally.

    The server listens over HTTPS using the mkcert-generated localhost certificate
    and key, and handles a single request per `intercept` call.

    ```python
    with RedirectServer(port=8080) as server:
        print(server.url)  # https://localhost:8080/
        print(server.is_running)  # True
        print(server.intercept(timeout=10))  # https://localhost:8080/?code=123456
    ```
    """

    # Class-level fallback for the (name-mangled) instance attribute. It keeps
    # `is_running` -- and therefore `stop`/`__del__` -- working on an object whose
    # `__init__` raised before assigning it.
    __instance = None

    def __init__(
        self, *, port: int | None = None, success_message: str | None = None, error_message: str | None = None
    ) -> None:
        """
        Args:
            port: Port to host on. Default: `41539`
            success_message: Message to display when successfully authorized.
            error_message: Message to display when request is denied or failed.
        """
        self.__instance: WSGIServer = None  # type: ignore[assignment]
        self.__host = "localhost"
        self.__port = port or 41539
        self.__app = RedirectWSGIApp(
            success_message=success_message or AUTH_SUCCESS_MESSAGE, error_message=error_message or AUTH_ERROR_MESSAGE
        )

    def __enter__(self) -> RedirectServer:
        self.start()
        return self

    def start(self) -> None:
        """Start the server over HTTPS. Does nothing if it is already running.

        Raises:
            FileNotFoundError: If the mkcert certificate or private key file is missing.
        """
        if self.is_running:
            return

        if not MKCERT_LOCALHOST_SSL_CERT_FILE.is_file() or not MKCERT_LOCALHOST_SSL_PKEY_FILE.is_file():
            raise FileNotFoundError(
                "Certificate or Private key file not found.\n\n"
                "Run the following commands to generate it:\n"
                "mkcert -install # Skip if already done.\n"
                f'mkcert -cert-file "{MKCERT_LOCALHOST_SSL_CERT_FILE!s}" '
                f'-key-file "{MKCERT_LOCALHOST_SSL_PKEY_FILE!s}" localhost'
            )

        # Load the certificate before binding. An unreadable or invalid cert/key then
        # fails without leaving a bound listening socket behind.
        sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        sslctx.check_hostname = False
        sslctx.load_cert_chain(certfile=MKCERT_LOCALHOST_SSL_CERT_FILE, keyfile=MKCERT_LOCALHOST_SSL_PKEY_FILE)

        # NOTE(review): this sets `allow_reuse_address` on the WSGIServer class itself,
        # i.e. for every WSGIServer created in the process, not only this one.
        WSGIServer.allow_reuse_address = False
        server_instance = make_server(
            self.__host,
            self.__port,
            self.__app,  # type: ignore[arg-type]
            handler_class=_WSGIRequestHandler,
        )

        try:
            server_instance.socket = sslctx.wrap_socket(sock=server_instance.socket, server_side=True)
        except BaseException:
            # Release the already-bound port instead of leaking the socket.
            server_instance.server_close()
            raise

        self.__instance = server_instance
        self.logger.info(f"Redirect server started on <u>{self.__host}:{self.__port}</u>")

    @property
    def is_running(self) -> bool:
        """Check if server is running."""
        return bool(self.__instance)

    def raise_if_not_running(self) -> None:
        """Raise RuntimeError if server is not running."""
        if not self.is_running:
            raise RuntimeError("Server not running. use `.start()` method.")

    @property
    def url(self) -> str:
        """URL of locally hosted server. Raises RuntimeError if it is not running."""
        self.raise_if_not_running()
        return f"https://{self.__host}:{self.__port}/"

    def intercept(self, *, timeout: int | None = None) -> str | None:
        """
        Intercept incoming request and return its URL.
        If `timeout` is None, server waits until a request is intercepted.

        Args:
            timeout: Seconds to wait for. Default: `None`

        Returns:
            Intercepted request URL if any, else None
        """
        self.raise_if_not_running()
        time_out_fmt = f"{timeout}s" if timeout is not None else "nil"
        self.logger.info(f"Intercepting request (timeout: <u>{time_out_fmt}</u>)")

        self.__instance.timeout = timeout
        self.__instance.handle_request()
        # NOTE(review): if `handle_request()` times out, this returns whatever the app
        # last recorded. That may be a URL from an earlier `intercept` call rather
        # than None, depending on RedirectWSGIApp (not visible here) -- confirm.
        return self.__app.last_request_url

    def stop(self) -> None:
        """Stop the server. Does nothing if it is not running."""
        if self.is_running:
            self.logger.info("Shutting down the server...")
            self.__instance.server_close()
            self.__instance = None  # type: ignore[assignment]

    def __del__(self) -> None:
        self.stop()

    def __exit__(self, exc_type: Any, exc_value: Any, tb: Any) -> None:
        self.stop()

is_running: bool property

Check if server is running.

url: str property

URL of locally hosted server.

intercept

Intercept an incoming request and return its URL. If `timeout` is None, the server waits until a request is intercepted.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `timeout` | `int \| None` | Seconds to wait for. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `str \| None` | Intercepted request URL if any, else None |

Source code in orchestr8/oauth_flow/redirect_server.py
def intercept(self, *, timeout: int | None = None) -> str | None:
    """
    Intercept incoming request and return its URL.
    If `timeout` is None, server waits until a request is intercepted.

    Args:
        timeout: Seconds to wait for. Default: `None`

    Returns:
        Intercepted request URL if any, else None
    """
    self.raise_if_not_running()
    time_out_fmt = f"{timeout}s" if timeout is not None else "nil"
    self.logger.info(f"Intercepting request (timeout: <u>{time_out_fmt}</u>)")

    self.__instance.timeout = timeout
    self.__instance.handle_request()
    return self.__app.last_request_url

raise_if_not_running

Raise if server is not running.

Source code in orchestr8/oauth_flow/redirect_server.py
def raise_if_not_running(self) -> None:
    """Guard that raises RuntimeError unless the server is currently running."""
    if self.is_running:
        return
    raise RuntimeError("Server not running. use `.start()` method.")

start

Start the server.

Source code in orchestr8/oauth_flow/redirect_server.py
def start(self) -> None:
    """Start the server over HTTPS. Does nothing if it is already running.

    Raises:
        FileNotFoundError: If the mkcert certificate or private key file is missing.
    """
    if self.is_running:
        return

    if not MKCERT_LOCALHOST_SSL_CERT_FILE.is_file() or not MKCERT_LOCALHOST_SSL_PKEY_FILE.is_file():
        raise FileNotFoundError(
            "Certificate or Private key file not found.\n\n"
            "Run the following commands to generate it:\n"
            "mkcert -install # Skip if already done.\n"
            f'mkcert -cert-file "{MKCERT_LOCALHOST_SSL_CERT_FILE!s}" '
            f'-key-file "{MKCERT_LOCALHOST_SSL_PKEY_FILE!s}" localhost'
        )

    # Load the certificate before binding. An unreadable or invalid cert/key then
    # fails without leaving a bound listening socket behind.
    sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    sslctx.check_hostname = False
    sslctx.load_cert_chain(certfile=MKCERT_LOCALHOST_SSL_CERT_FILE, keyfile=MKCERT_LOCALHOST_SSL_PKEY_FILE)

    # NOTE(review): this sets `allow_reuse_address` on the WSGIServer class itself,
    # i.e. for every WSGIServer created in the process, not only this one.
    WSGIServer.allow_reuse_address = False
    server_instance = make_server(
        self.__host,
        self.__port,
        self.__app,  # type: ignore[arg-type]
        handler_class=_WSGIRequestHandler,
    )

    try:
        server_instance.socket = sslctx.wrap_socket(sock=server_instance.socket, server_side=True)
    except BaseException:
        # Release the already-bound port instead of leaking the socket.
        server_instance.server_close()
        raise

    self.__instance = server_instance
    self.logger.info(f"Redirect server started on <u>{self.__host}:{self.__port}</u>")

stop

Stop the server.

Source code in orchestr8/oauth_flow/redirect_server.py
def stop(self) -> None:
    """Shut the server down and forget it; a no-op when it is not running."""
    if not self.is_running:
        return
    self.logger.info("Shutting down the server...")
    self.__instance.server_close()
    self.__instance = None  # type: ignore[assignment]