blob: 45862ef22e690b4351a3b397fec838d7f6921d8c [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2023 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""The server side code of corp-ssh-helper-helper.
The server process waits for connection requests from client processes.
When receiving a request, it runs corp-ssh-helper to perform network IO on
behalf of the client.
"""
import argparse
import array
import json
import logging
import os
import re
import socket
import subprocess
import sys
import threading
from typing import Any, Dict, Optional, Sequence
# Maximum number of bytes of request data read from a client connection in a
# single recvmsg() call.
_MAX_DATA_SIZE = 4 * 1024
# File name of this script. Used as the program name of the argument parser.
_SCRIPT_NAME = os.path.split(__file__)[1]
def _find_chromiumos_checkout_root(path: str) -> Optional[str]:
    """Returns the ChromiumOS checkout root path.

    Walks up from the given path, one directory level at a time, until a
    directory which contains all of "chroot", "chromite" and "src" is found.

    Args:
        path: The path to start searching from.

    Returns:
        The checkout root path if the given path belongs to a ChromiumOS
        checkout. Otherwise None.
    """
    # Known subdirectories which exist under the checkout root.
    known_children = ("chroot", "chromite", "src")
    # Give up and return None when reaching the file system root.
    while path != "/":
        if all(
            os.path.exists(os.path.join(path, child_name))
            for child_name in known_children
        ):
            # This is the checkout root. Return the path.
            return path
        # This doesn't look like a checkout root. Move one level up and retry.
        parent = os.path.dirname(path)
        if parent == path:
            # dirname() made no progress (e.g. "" for a relative path, or
            # "//"). Stop here instead of looping forever.
            return None
        path = parent
    return None
def _daemonize() -> None:
    """Detaches the current process and keeps it running in the background.

    Forks twice. The original process and the first forked process both exit
    with status 0, and only the second forked process returns from this
    function, with its standard IO redirected to /dev/null.
    """
    for is_first_fork in (True, False):
        child_pid = os.fork()
        if child_pid != 0:
            # This is the parent side of the fork. It has nothing left to do.
            sys.exit(0)
        if is_first_fork:
            # The first forked process starts its own session before forking
            # once more.
            os.setsid()
    # Only the second forked process reaches here. Replace its standard IO
    # with /dev/null.
    with open(os.devnull, "r+b") as devnull:
        for stream in (sys.stdin, sys.stdout, sys.stderr):
            os.dup2(devnull.fileno(), stream.fileno())
def _is_valid_host(host: Any) -> bool:
    """Returns whether host looks like an IP address or a host name.

    Accepts an IPv4 address, an IPv6 address, or a host name which satisfies
    the spec described in `man hosts`. The whole string has to match, so a
    value with a trailing newline or a leading "-" is rejected.
    """
    if not isinstance(host, str):
        return False
    return bool(
        re.fullmatch(r"[0-9.]+", host)  # IPv4 address
        or re.fullmatch(r"[0-9a-fA-F:]+", host)  # IPv6 address
        or re.fullmatch(r"[A-Za-z][A-Za-z0-9.-]*", host)  # host name
    )


def _is_valid_port(port: Any) -> bool:
    """Returns whether port is a non-empty string of ASCII digits."""
    # [0-9] is used instead of \d because \d also matches non-ASCII digits.
    return isinstance(port, str) and bool(re.fullmatch(r"[0-9]+", port))


def _handle_request(
    conn: socket.socket, msg: Dict[str, Any], fds: Sequence[int]
) -> None:
    """Handles an incoming connection request from a client.

    Validates the requested destination, runs corp-ssh-helper with the
    client's standard IO attached, and then sends the helper's return code
    back to the client as a JSON message.

    Args:
        conn: The socket connected to the client. It is closed on return.
        msg: The request from the client. The "host", "port", "dest4",
            "dest6", "relay", "proxy_mode" and "dst_username" keys are read.
        fds: The file descriptors passed by the client. The first three are
            used as stdin, stdout and stderr of corp-ssh-helper in that
            order, and are closed on return.

    Raises:
        RuntimeError: If the host name or the port number is not valid.
        KeyError: If msg lacks one of the keys listed above.
        IndexError: If fewer than three file descriptors are given.
    """
    with conn:
        # Run corp-ssh-helper with the passed parameters.
        with (
            os.fdopen(fds[0], "rb") as stdin,
            os.fdopen(fds[1], "wb") as stdout,
            os.fdopen(fds[2], "wb") as stderr,
        ):
            # The host and the port are appended to the command line below,
            # so check them strictly before using them.
            host = msg["host"]
            if not _is_valid_host(host):
                raise RuntimeError(f'"{host}" is not a valid host name')
            port = msg["port"]
            if not _is_valid_port(port):
                raise RuntimeError(f'"{port}" is not a valid port number')
            args = ["corp-ssh-helper"]
            if msg["dest4"]:
                args.append("-dest4")
            if msg["dest6"]:
                args.append("-dest6")
            if msg["relay"]:
                args.extend(("-relay", msg["relay"]))
            if msg["proxy_mode"]:
                args.extend(("--proxy-mode", msg["proxy_mode"]))
            if msg["dst_username"]:
                args.extend(("-dst_username", msg["dst_username"]))
            args.append(host)
            args.append(port)
            logging.info(
                'Connecting to %s:%s. proxycommand="%s"',
                host,
                port,
                " ".join(args),
            )
            returncode = subprocess.run(
                args, stdin=stdin, stdout=stdout, stderr=stderr, check=False
            ).returncode
        # After the proxy command exits, send a response JSON to the client.
        conn.send(json.dumps({"returncode": returncode}).encode("utf-8"))
def _close_fds(fds: Sequence[int]) -> None:
    """Closes file descriptors received from a client which won't be used."""
    for fd in fds:
        try:
            os.close(fd)
        except OSError as err:
            logging.warning("Failed to close FD %d: %s", fd, err)


def _serve_connection(conn: socket.socket) -> bool:
    """Receives one request from an accepted connection and dispatches it.

    A valid connection request is handed to a new thread running
    _handle_request(), which then owns the connection and the received FDs.
    In every other case the connection and the received FDs are closed here.

    Args:
        conn: The socket returned by accept().

    Returns:
        True if a kill request was received and the server should exit.
        Otherwise False.
    """
    fds = array.array("i")
    handed_off = False
    try:
        # Receive the request message and standard IO FDs.
        try:
            msg_bytes, ancdata, _, _ = conn.recvmsg(
                _MAX_DATA_SIZE, socket.CMSG_LEN(3 * fds.itemsize)
            )
        except OSError as err:
            logging.error("Failed to receive a message: %s", err)
            return False
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if (
                cmsg_level == socket.SOL_SOCKET
                and cmsg_type == socket.SCM_RIGHTS
            ):
                # Ignore a trailing partial item, if any.
                usable_len = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
                fds.frombytes(cmsg_data[:usable_len])
        # Interpret the message as a JSON. Both JSONDecodeError and
        # UnicodeDecodeError (non-UTF-8 bytes) are subclasses of ValueError.
        try:
            msg = json.loads(msg_bytes)
        except ValueError as err:
            logging.error("Invalid message from the client: %s", err)
            return False
        if not isinstance(msg, dict):
            logging.error(
                "Invalid message from the client: %s",
                "the message is not a JSON object",
            )
            return False
        # When receiving a kill request, just exit.
        if msg.get("kill"):
            logging.info("Received a kill request. Going to exit.")
            return True
        if len(fds) < 3:
            logging.error(
                "Invalid message from the client: %s",
                "standard IO FDs were not passed",
            )
            return False
        # Start a new thread to handle the connection request.
        threading.Thread(target=_handle_request, args=(conn, msg, fds)).start()
        handed_off = True
        return False
    finally:
        if not handed_off:
            # Nobody else is going to use these. Don't leak them.
            conn.close()
            _close_fds(fds)


def main() -> int:
    """The main function.

    Parses the command line, makes sure that at most one server process uses
    the socket under the ChromiumOS checkout root (or asks the existing one to
    exit when --kill is given), and then serves client requests until a kill
    request is received.

    Returns:
        The exit status of the process. 0 on success, 1 on failure.
    """
    logging.basicConfig(level=logging.INFO)
    # The named UNIX domain socket created by this process should be used only
    # by the same user.
    os.umask(0o077)
    arg_parser = argparse.ArgumentParser(prog=_SCRIPT_NAME)
    arg_parser.add_argument(
        "--foreground", action="store_true", help="Run in the foreground"
    )
    arg_parser.add_argument(
        "--kill", action="store_true", help="Kill the existing server process"
    )
    args = arg_parser.parse_args()
    # Try to find the checkout root from the current directory's ancestors.
    chromeos_path = _find_chromiumos_checkout_root(os.getcwd())
    if not chromeos_path:
        logging.error("Please run this script under a ChromiumOS checkout.")
        return 1
    socket_path = os.path.join(chromeos_path, ".corp-ssh-helper-helper.sock")
    if os.path.exists(socket_path):
        # Try connecting to the existing socket to check if there is an existing
        # server process.
        with socket.socket(socket.AF_UNIX) as s:
            try:
                s.connect(socket_path)
            except ConnectionRefusedError:
                # ConnectionRefusedError means there is no existing server
                # process.
                if args.kill:
                    logging.error(
                        "--kill was specified, but there is no "
                        "existing process to kill."
                    )
                    return 1
                # Delete the existing socket and continue.
                os.remove(socket_path)
            else:
                # If connect() doesn't raise an exception, that means there is a
                # running server process.
                if args.kill:
                    # Send a kill request to the existing server process.
                    msg_bytes = json.dumps({"kill": True}).encode("utf-8")
                    s.sendmsg([msg_bytes])
                    logging.info(
                        "Sent a kill request to the existing server process."
                    )
                    return 0
                logging.error(
                    "The server is already running. If you want "
                    "to stop it, please use --kill."
                )
                return 1
    # Start listening on the socket.
    with socket.socket(socket.AF_UNIX) as s:
        s.bind(socket_path)
        s.listen()
        logging.info("Listening on %s", socket_path)
        # After this, run in the background unless running with --foreground.
        if not args.foreground:
            _daemonize()
        # The main loop of this server process.
        while True:
            # Wait for a client connection.
            conn, _ = s.accept()
            if _serve_connection(conn):
                return 0
# Run the server only when this file is executed as a script, and use the
# return value of main() as the exit status of the process.
if __name__ == "__main__":
    sys.exit(main())