blob: 1dcc8288ce6832810fd649967550d0f96e136e2b [file]
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Shared event management helpers for generated WebDriver BiDi modules.
``EventConfig``, ``_EventWrapper``, and ``_EventManager`` are emitted
identically into every generated module that exposes events. Rather than
duplicating this logic across those modules, they are defined once here and
copied into generated outputs by Bazel.
"""
from __future__ import annotations
import threading
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from selenium.webdriver.common.bidi.session import Session
@dataclass
class EventConfig:
    """Static description of a single BiDi event exposed by a module.

    Instances are consumed by ``_EventManager``, which reads ``bidi_event``
    and ``event_class`` to build its per-event wrappers.
    """

    # Short name for the event. Presumably the same key under which this
    # config is stored in the ``event_configs`` mapping -- confirm against
    # the code generator.
    event_key: str
    # BiDi event name used when subscribing and when registering callbacks.
    bidi_event: str
    # Python class that incoming event params are deserialized into.
    event_class: type
class _EventWrapper:
    """Wrapper to provide event_class attribute for WebSocketConnection callbacks.

    ``event_class`` holds the BiDi event *name* (a string), not a Python
    class. The Python class used for deserialization is kept separately and
    applied by :meth:`from_json`.
    """

    def __init__(self, bidi_event: str, event_class: type):
        """Store the BiDi event name and the class to deserialize into.

        Args:
            bidi_event: BiDi event name; exposed as ``self.event_class``.
            event_class: Python class that :meth:`from_json` instantiates.
        """
        # WebSocket expects the BiDi event name as event_class
        self.event_class = bidi_event
        # Keep reference to Python dataclass for deserialization
        self._python_class = event_class

    def from_json(self, params: dict) -> Any:
        """Deserialize event params into the wrapped Python dataclass.

        Resolution order:

        1. If the wrapped class is ``None`` or ``dict``, ``params`` is
           returned unchanged.
        2. If the wrapped class has a callable ``from_json`` attribute, it is
           called with the raw ``params``.
        3. Otherwise keys are converted from camelCase to snake_case and
           passed as keyword arguments. For dataclasses, keys that do not
           match a declared field are dropped first.

        Args:
            params: Raw BiDi event params with camelCase keys.

        Returns:
            An instance of the dataclass, or the raw dict on failure.
        """
        target = self._python_class
        if target is None or target is dict:
            return params

        try:
            # Delegate to a classmethod from_json if the class defines one
            custom_loader = getattr(target, "from_json", None)
            if callable(custom_loader):
                return custom_loader(params)

            import dataclasses

            kwargs = {self._camel_to_snake(key): value for key, value in params.items()}
            if dataclasses.is_dataclass(target):
                # Ignore params the dataclass does not declare so that extra
                # keys do not make the constructor fail.
                field_names = {field.name for field in dataclasses.fields(target)}
                kwargs = {key: value for key, value in kwargs.items() if key in field_names}
            return target(**kwargs)
        except Exception:
            # Deliberate best effort: any failure while building the object
            # falls back to handing the caller the raw params.
            return params

    @staticmethod
    def _camel_to_snake(name: str) -> str:
        """Convert a camelCase identifier to snake_case.

        The first character is lowercased; every later uppercase character
        is replaced by an underscore followed by its lowercase form.

        Args:
            name: The camelCase name. May be empty.

        Returns:
            The snake_case name; an empty ``name`` is returned unchanged.
        """
        if not name:
            # Guard the name[0] access below: an empty key would otherwise
            # raise IndexError and make from_json discard the whole event.
            return name
        tail = "".join(f"_{char.lower()}" if char.isupper() else char for char in name[1:])
        return name[0].lower() + tail
class _EventManager:
    """Manages event subscriptions and callbacks.

    ``subscriptions`` maps a BiDi event name to an entry of the form
    ``{"callbacks": [callback_id, ...], "subscription_id": <id or None>}``.
    All reads and writes of that mapping performed by this class happen while
    holding ``_subscription_lock``.
    """

    def __init__(self, conn, event_configs: dict[str, EventConfig]):
        """Build the manager and one ``_EventWrapper`` per configured event.

        Args:
            conn: Connection object; must provide ``add_callback`` and
                ``remove_callback`` and is passed to ``Session``.
            event_configs: Mapping of event name to its ``EventConfig``.
        """
        self.conn = conn
        self.event_configs = event_configs
        self.subscriptions: dict = {}
        self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()}
        self._available_events = ", ".join(sorted(event_configs.keys()))
        self._subscription_lock = threading.Lock()
        # Cache of _EventWrapper objects, keyed by BiDi event name
        self._event_wrappers = {
            config.bidi_event: _EventWrapper(config.bidi_event, config.event_class)
            for config in event_configs.values()
        }

    def validate_event(self, event: str) -> EventConfig:
        """Look up the configuration for ``event``.

        Args:
            event: Event name, as used for the keys of ``event_configs``.

        Returns:
            The matching ``EventConfig``.

        Raises:
            ValueError: If ``event`` is not a configured event.
        """
        event_config = self.event_configs.get(event)
        if not event_config:
            raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}")
        return event_config

    def _ensure_subscription(self, bidi_event: str, contexts: list[str] | None = None) -> dict:
        """Return the tracking entry for ``bidi_event``, subscribing if absent.

        The caller must already hold ``_subscription_lock``.
        """
        if bidi_event not in self.subscriptions:
            session = Session(self.conn)
            result = session.subscribe([bidi_event], contexts=contexts)
            sub_id = result.get("subscription") if isinstance(result, dict) else None
            self.subscriptions[bidi_event] = {
                "callbacks": [],
                "subscription_id": sub_id,
            }
        return self.subscriptions[bidi_event]

    def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None:
        """Subscribe to a BiDi event if not already subscribed.

        Args:
            bidi_event: BiDi event name to subscribe to.
            contexts: Optional contexts forwarded to ``Session.subscribe``.
                Ignored when a subscription for ``bidi_event`` already exists.
        """
        with self._subscription_lock:
            self._ensure_subscription(bidi_event, contexts)

    def unsubscribe_from_event(self, bidi_event: str) -> None:
        """Unsubscribe from a BiDi event if no more callbacks exist.

        Does nothing when the event is not tracked or still has callbacks.
        Unsubscribes by subscription id when one was recorded, otherwise by
        event name.

        Args:
            bidi_event: BiDi event name to unsubscribe from.
        """
        with self._subscription_lock:
            entry = self.subscriptions.get(bidi_event)
            if entry is None or entry["callbacks"]:
                return
            session = Session(self.conn)
            sub_id = entry.get("subscription_id")
            if sub_id:
                session.unsubscribe(subscriptions=[sub_id])
            else:
                session.unsubscribe(events=[bidi_event])
            del self.subscriptions[bidi_event]

    def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None:
        """Record ``callback_id`` against an existing subscription entry.

        Args:
            bidi_event: BiDi event name the callback belongs to.
            callback_id: Identifier to append to the entry's callback list.

        Raises:
            KeyError: If ``bidi_event`` has no subscription entry.
        """
        with self._subscription_lock:
            self.subscriptions[bidi_event]["callbacks"].append(callback_id)

    def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None:
        """Forget ``callback_id`` for ``bidi_event``; a no-op if not tracked.

        Args:
            bidi_event: BiDi event name the callback belongs to.
            callback_id: Identifier to remove from the entry's callback list.
        """
        with self._subscription_lock:
            entry = self.subscriptions.get(bidi_event)
            if entry and callback_id in entry["callbacks"]:
                entry["callbacks"].remove(callback_id)

    def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int:
        """Register ``callback`` for ``event`` and make sure it is subscribed.

        Args:
            event: Event name; validated with :meth:`validate_event`.
            callback: Callable registered on the connection.
            contexts: Optional contexts used if a new subscription is made.

        Returns:
            The callback id returned by ``conn.add_callback``.

        Raises:
            ValueError: If ``event`` is not a configured event.
            Exception: Whatever subscribing raises; the callback is removed
                from the connection again before the error propagates.
        """
        event_config = self.validate_event(event)
        bidi_event = event_config.bidi_event
        # Use the event wrapper for add_callback
        event_wrapper = self._event_wrappers.get(bidi_event)
        # Register the callback before subscribing, as the original ordering did.
        callback_id = self.conn.add_callback(event_wrapper, callback)
        try:
            # Subscribe and track under one lock acquisition so a concurrent
            # remove_event_handler cannot delete the entry in between.
            with self._subscription_lock:
                entry = self._ensure_subscription(bidi_event, contexts)
                entry["callbacks"].append(callback_id)
        except Exception:
            # Do not leave an untracked callback registered on the connection.
            self.conn.remove_callback(event_wrapper, callback_id)
            raise
        return callback_id

    def remove_event_handler(self, event: str, callback_id: int) -> None:
        """Remove a callback and unsubscribe if it was the last one.

        Args:
            event: Event name; validated with :meth:`validate_event`.
            callback_id: Identifier returned by :meth:`add_event_handler`.

        Raises:
            ValueError: If ``event`` is not a configured event.
        """
        event_config = self.validate_event(event)
        bidi_event = event_config.bidi_event
        event_wrapper = self._event_wrappers.get(bidi_event)
        self.conn.remove_callback(event_wrapper, callback_id)
        self.remove_callback_from_tracking(bidi_event, callback_id)
        self.unsubscribe_from_event(bidi_event)

    def clear_event_handlers(self) -> None:
        """Clear all event handlers.

        For every tracked event, removes its callbacks from the connection
        and unsubscribes (by subscription id when recorded, otherwise by
        event name), then empties ``subscriptions``.
        """
        with self._subscription_lock:
            if not self.subscriptions:
                return
            session = Session(self.conn)
            for bidi_event, entry in list(self.subscriptions.items()):
                # Entries are normally dicts; a non-dict entry is treated as
                # a bare collection of callback ids with no subscription id.
                entry_is_dict = isinstance(entry, dict)
                callbacks = entry["callbacks"] if entry_is_dict else entry
                sub_id = entry.get("subscription_id") if entry_is_dict else None

                event_wrapper = self._event_wrappers.get(bidi_event)
                if event_wrapper:
                    for callback_id in callbacks:
                        self.conn.remove_callback(event_wrapper, callback_id)

                if sub_id:
                    session.unsubscribe(subscriptions=[sub_id])
                else:
                    session.unsubscribe(events=[bidi_event])
            self.subscriptions.clear()