Merge remote-tracking branch 'origin/trunk' into router-cache
diff --git a/java/src/org/openqa/selenium/firefox/GeckoDriverService.java b/java/src/org/openqa/selenium/firefox/GeckoDriverService.java index 6563519..2a8720b 100644 --- a/java/src/org/openqa/selenium/firefox/GeckoDriverService.java +++ b/java/src/org/openqa/selenium/firefox/GeckoDriverService.java
@@ -147,6 +147,8 @@ public static class Builder private @Nullable FirefoxDriverLogLevel logLevel; private @Nullable Boolean logTruncate; private @Nullable File profileRoot; + private @Nullable Integer marionettePort; + private @Nullable Integer websocketPort; @Override public int score(Capabilities capabilities) { @@ -204,6 +206,31 @@ public GeckoDriverService.Builder withProfileRoot(@Nullable File root) { return this; } + /** + * Configures geckodriver to connect to an existing Firefox instance via the specified + * Marionette port. + * + * @param marionettePort The port where Marionette is listening on the existing Firefox + * instance. + * @return A self reference. + */ + public GeckoDriverService.Builder connectToExisting(int marionettePort) { + this.marionettePort = marionettePort; + return this; + } + + /** + * Configures the WebSocket port for BiDi. A value of 0 will automatically allocate a free port. + * + * @param websocketPort The port to use for WebSocket communication, or 0 for automatic + * allocation. + * @return A self reference. 
+ */ + public GeckoDriverService.Builder withWebSocketPort(@Nullable Integer websocketPort) { + this.websocketPort = websocketPort; + return this; + } + @Override protected void loadSystemProperties() { parseLogOutput(GECKO_DRIVER_LOG_PROPERTY); @@ -229,13 +256,27 @@ protected List<String> createArgs() { List<String> args = new ArrayList<>(); args.add(String.format(Locale.ROOT, "--port=%d", getPort())); - int wsPort = PortProber.findFreePort(); - args.add(String.format("--websocket-port=%d", wsPort)); + // Check if marionette port is specified via connectToExisting method + if (marionettePort != null) { + args.add("--connect-existing"); + args.add("--marionette-port"); + args.add(String.valueOf(marionettePort)); + } else { + // Configure websocket port for BiDi communication + if (websocketPort != null) { + args.add("--websocket-port"); + args.add(String.valueOf(websocketPort)); - args.add("--allow-origins"); - args.add(String.format("http://127.0.0.1:%d", wsPort)); - args.add(String.format("http://localhost:%d", wsPort)); - args.add(String.format("http://[::1]:%d", wsPort)); + args.add("--allow-origins"); + args.add(String.format("http://127.0.0.1:%d", websocketPort)); + args.add(String.format("http://localhost:%d", websocketPort)); + args.add(String.format("http://[::1]:%d", websocketPort)); + } else { + // Use 0 to auto-allocate a free port + args.add("--websocket-port"); + args.add("0"); + } + } if (logLevel != null) { args.add("--log");
diff --git a/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java b/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java index bc448ae..609a8a2 100644 --- a/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java +++ b/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java
@@ -539,7 +539,7 @@ public void run() { // up starving a session request. Map<Capabilities, Long> stereotypes = getAvailableNodes().stream() - .filter(node -> node.hasCapacity()) + .filter(NodeStatus::hasCapacity) .flatMap(node -> node.getSlots().stream().map(Slot::getStereotype)) .collect( Collectors.groupingBy(ImmutableCapabilities::copyOf, Collectors.counting()));
diff --git a/java/src/org/openqa/selenium/grid/distributor/local/LocalNodeRegistry.java b/java/src/org/openqa/selenium/grid/distributor/local/LocalNodeRegistry.java index 54ef10b..097657a 100644 --- a/java/src/org/openqa/selenium/grid/distributor/local/LocalNodeRegistry.java +++ b/java/src/org/openqa/selenium/grid/distributor/local/LocalNodeRegistry.java
@@ -361,9 +361,8 @@ public Set<NodeStatus> getAvailableNodes() { readLock.lock(); try { return model.getSnapshot().stream() - .filter( - node -> - !DOWN.equals(node.getAvailability()) && !DRAINING.equals(node.getAvailability())) + // Filter nodes are UP and have capacity (available slots) + .filter(node -> UP.equals(node.getAvailability()) && node.hasCapacity()) .collect(ImmutableSet.toImmutableSet()); } finally { readLock.unlock();
diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java b/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java index 2b88bdc..66be163 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java +++ b/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java
@@ -55,6 +55,7 @@ import org.openqa.selenium.grid.data.TraceSessionRequest; import org.openqa.selenium.grid.distributor.config.DistributorOptions; import org.openqa.selenium.grid.jmx.JMXHelper; +import org.openqa.selenium.grid.jmx.MBean; import org.openqa.selenium.grid.jmx.ManagedAttribute; import org.openqa.selenium.grid.jmx.ManagedService; import org.openqa.selenium.grid.log.LoggingOptions; @@ -110,6 +111,7 @@ public class LocalNewSessionQueue extends NewSessionQueue implements Closeable { thread.setName(NAME); return thread; }); + private final MBean jmxBean; public LocalNewSessionQueue( Tracer tracer, @@ -139,7 +141,8 @@ public LocalNewSessionQueue( requestTimeoutCheck.toMillis(), MILLISECONDS); - new JMXHelper().register(this); + // Manage JMX and unregister on close() + this.jmxBean = new JMXHelper().register(this); } public static NewSessionQueue create(Config config) { @@ -502,6 +505,10 @@ public boolean isReady() { @Override public void close() { shutdownGracefully(NAME, service); + + if (jmxBean != null) { + new JMXHelper().unregister(jmxBean.getObjectName()); + } } private void failDueToTimeout(RequestId reqId) {
diff --git a/java/test/org/openqa/selenium/firefox/FirefoxDriverConcurrentTest.java b/java/test/org/openqa/selenium/firefox/FirefoxDriverConcurrentTest.java index 3b92e5b..9b78c1d 100644 --- a/java/test/org/openqa/selenium/firefox/FirefoxDriverConcurrentTest.java +++ b/java/test/org/openqa/selenium/firefox/FirefoxDriverConcurrentTest.java
@@ -28,6 +28,7 @@ import org.openqa.selenium.ParallelTestRunner.Worker; import org.openqa.selenium.WebDriver; import org.openqa.selenium.WebElement; +import org.openqa.selenium.bidi.BiDi; import org.openqa.selenium.testing.JupiterTestBase; import org.openqa.selenium.testing.drivers.WebDriverBuilder; @@ -164,4 +165,69 @@ void shouldBeAbleToUseTheSameProfileMoreThanOnce() { if (two != null) two.quit(); } } + + @Test + void multipleFirefoxInstancesWithBiDiEnabledCanRunSimultaneously() { + // Create two Firefox instances with BiDi enabled, should use different ports + FirefoxOptions options1 = new FirefoxOptions().enableBiDi(); + FirefoxOptions options2 = new FirefoxOptions().enableBiDi(); + + WebDriver driver1 = null; + WebDriver driver2 = null; + + try { + driver1 = new WebDriverBuilder().get(options1); + BiDi biDi1 = ((FirefoxDriver) driver1).getBiDi(); + assertThat(biDi1).isNotNull(); + + // Extract the BiDi websocket URL and port for the first instance + String webSocketUrl1 = + (String) ((FirefoxDriver) driver1).getCapabilities().getCapability("webSocketUrl"); + String port1 = webSocketUrl1.replaceAll("^ws://[^:]+:(\\d+)/.*$", "$1"); + + driver2 = new WebDriverBuilder().get(options2); + BiDi biDi2 = ((FirefoxDriver) driver2).getBiDi(); + assertThat(biDi2).isNotNull(); + + // Extract the BiDi websocket URL and port for the second instance + String webSocketUrl2 = + (String) ((FirefoxDriver) driver2).getCapabilities().getCapability("webSocketUrl"); + String port2 = webSocketUrl2.replaceAll("^ws://[^:]+:(\\d+)/.*$", "$1"); + + // Verify that the ports are different + assertThat(port1).isNotEqualTo(port2); + } finally { + // Clean up + if (driver1 != null) { + driver1.quit(); + } + if (driver2 != null) { + driver2.quit(); + } + } + } + + @Test + void geckoDriverServiceConnectToExistingFirefox() { + GeckoDriverService.Builder builder = new GeckoDriverService.Builder(); + + // Test connectToExisting method + builder.connectToExisting(2829); + GeckoDriverService service 
= builder.build(); + + assertThat(service).isNotNull(); + service.stop(); + } + + @Test + void geckoDriverServiceCustomWebSocketPort() { + GeckoDriverService.Builder builder = new GeckoDriverService.Builder(); + + // Test withWebSocketPort method + builder.withWebSocketPort(9225); + GeckoDriverService service = builder.build(); + + assertThat(service).isNotNull(); + service.stop(); + } }
diff --git a/java/test/org/openqa/selenium/grid/distributor/local/LocalDistributorTest.java b/java/test/org/openqa/selenium/grid/distributor/local/LocalDistributorTest.java index a0092c9..eae8de5 100644 --- a/java/test/org/openqa/selenium/grid/distributor/local/LocalDistributorTest.java +++ b/java/test/org/openqa/selenium/grid/distributor/local/LocalDistributorTest.java
@@ -494,6 +494,351 @@ void slowStartingNodesShouldNotCauseReservationsToBeSerialized() { assertThat(System.currentTimeMillis() - start).isLessThan(delay * 2); } + @Test + void shouldOnlyReturnNodesWithFreeSlots() throws URISyntaxException { + // Create a distributor + NewSessionQueue queue = + new LocalNewSessionQueue( + tracer, + new DefaultSlotMatcher(), + Duration.ofSeconds(2), + Duration.ofSeconds(2), + Duration.ofSeconds(1), + registrationSecret, + 5); + LocalDistributor distributor = + new LocalDistributor( + tracer, + bus, + new PassthroughHttpClient.Factory(localNode), + new LocalSessionMap(tracer, bus), + queue, + new DefaultSlotSelector(), + registrationSecret, + Duration.ofMinutes(5), + false, + Duration.ofSeconds(5), + newSessionThreadPoolSize, + new DefaultSlotMatcher(), + Duration.ofSeconds(30)); + + // Create two nodes - both initially have free slots + URI nodeUri1 = new URI("http://example:1234"); + URI nodeUri2 = new URI("http://example:5678"); + + // Node 1: Has free slots + Node node1 = + LocalNode.builder(tracer, bus, nodeUri1, nodeUri1, registrationSecret) + .add( + new ImmutableCapabilities("browserName", "cheese"), + new TestSessionFactory( + (id, c) -> + new Session(id, nodeUri1, new ImmutableCapabilities(), c, Instant.now()))) + .build(); + + // Node 2: Will be fully occupied + Node node2 = + LocalNode.builder(tracer, bus, nodeUri2, nodeUri2, registrationSecret) + .add( + new ImmutableCapabilities("browserName", "cheese"), + new TestSessionFactory( + (id, c) -> + new Session(id, nodeUri2, new ImmutableCapabilities(), c, Instant.now()))) + .build(); + + // Add both nodes to distributor + distributor.add(node1); + distributor.add(node2); + + // Initially both nodes should be available + Set<NodeStatus> initialAvailableFreeNodes = distributor.getAvailableNodes(); + assertThat(initialAvailableFreeNodes).hasSize(2); + + // Create a session to occupy one slot + SessionRequest sessionRequest = + new SessionRequest( + new 
RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "cheese")), + Map.of(), + Map.of()); + + // Create session - this will occupy one slot on one of the nodes + distributor.newSession(sessionRequest); + + // Now test getAvailableNodes - should return nodes that still have free slots + Set<NodeStatus> availableFreeNodes = distributor.getAvailableNodes(); + + // Both nodes should still be available since each has only 1 slot and we created 1 session + // But let's verify the logic by checking that all returned nodes have free slots + for (NodeStatus nodeStatus : availableFreeNodes) { + assertThat(nodeStatus.getAvailability()).isEqualTo(UP); + + // Verify node has at least one free slot + boolean hasFreeSlot = + nodeStatus.getSlots().stream().anyMatch(slot -> slot.getSession() == null); + assertThat(hasFreeSlot).isTrue(); + } + + // Create another session to fully occupy both nodes + SessionRequest sessionRequest2 = + new SessionRequest( + new RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "cheese")), + Map.of(), + Map.of()); + + distributor.newSession(sessionRequest2); + + // Now both nodes should be fully occupied, so getAvailableNodes should return empty + Set<NodeStatus> fullyOccupiedNodes = distributor.getAvailableNodes(); + assertThat(fullyOccupiedNodes).isEmpty(); + } + + @Test + void shouldNotReturnDrainingNodes() throws URISyntaxException { + // Create a distributor + NewSessionQueue queue = + new LocalNewSessionQueue( + tracer, + new DefaultSlotMatcher(), + Duration.ofSeconds(2), + Duration.ofSeconds(2), + Duration.ofSeconds(1), + registrationSecret, + 5); + LocalDistributor distributor = + new LocalDistributor( + tracer, + bus, + new PassthroughHttpClient.Factory(localNode), + new LocalSessionMap(tracer, bus), + queue, + new DefaultSlotSelector(), + registrationSecret, + Duration.ofMinutes(5), + false, + Duration.ofSeconds(5), + 
newSessionThreadPoolSize, + new DefaultSlotMatcher(), + Duration.ofSeconds(30)); + + // Create a node + URI nodeUri = new URI("http://example:1234"); + Node node = + LocalNode.builder(tracer, bus, nodeUri, nodeUri, registrationSecret) + .add( + new ImmutableCapabilities("browserName", "cheese"), + new TestSessionFactory( + (id, c) -> + new Session(id, nodeUri, new ImmutableCapabilities(), c, Instant.now()))) + .build(); + + // Add node to distributor + distributor.add(node); + + // Initially, node should be available + Set<NodeStatus> availableFreeNodes = distributor.getAvailableNodes(); + assertThat(availableFreeNodes).hasSize(1); + assertThat(availableFreeNodes.iterator().next().getAvailability()).isEqualTo(UP); + + // Drain the node + distributor.drain(node.getId()); + + // After draining, node should not be returned by getAvailableNodes + availableFreeNodes = distributor.getAvailableNodes(); + assertThat(availableFreeNodes).isEmpty(); + } + + @Test + void shouldNotReturnDownNodes() throws URISyntaxException { + // Create a distributor + NewSessionQueue queue = + new LocalNewSessionQueue( + tracer, + new DefaultSlotMatcher(), + Duration.ofSeconds(2), + Duration.ofSeconds(2), + Duration.ofSeconds(1), + registrationSecret, + 5); + LocalDistributor distributor = + new LocalDistributor( + tracer, + bus, + new PassthroughHttpClient.Factory(localNode), + new LocalSessionMap(tracer, bus), + queue, + new DefaultSlotSelector(), + registrationSecret, + Duration.ofMinutes(5), + false, + Duration.ofSeconds(5), + newSessionThreadPoolSize, + new DefaultSlotMatcher(), + Duration.ofSeconds(30)); + + // Create a node + URI nodeUri = new URI("http://example:1234"); + Node node = + LocalNode.builder(tracer, bus, nodeUri, nodeUri, registrationSecret) + .add( + new ImmutableCapabilities("browserName", "cheese"), + new TestSessionFactory( + (id, c) -> + new Session(id, nodeUri, new ImmutableCapabilities(), c, Instant.now()))) + .build(); + + // Add node to distributor + 
distributor.add(node); + + // Initially, node should be available + Set<NodeStatus> availableFreeNodes = distributor.getAvailableNodes(); + assertThat(availableFreeNodes).hasSize(1); + + // Remove the node (simulates DOWN state) + distributor.remove(node.getId()); + + // After removal, node should not be returned by getAvailableNodes + availableFreeNodes = distributor.getAvailableNodes(); + assertThat(availableFreeNodes).isEmpty(); + } + + @Test + void shouldReduceRedundantSlotChecks() throws URISyntaxException { + // Create a distributor + NewSessionQueue queue = + new LocalNewSessionQueue( + tracer, + new DefaultSlotMatcher(), + Duration.ofSeconds(2), + Duration.ofSeconds(2), + Duration.ofSeconds(1), + registrationSecret, + 5); + LocalDistributor distributor = + new LocalDistributor( + tracer, + bus, + new PassthroughHttpClient.Factory(localNode), + new LocalSessionMap(tracer, bus), + queue, + new DefaultSlotSelector(), + registrationSecret, + Duration.ofMinutes(5), + false, + Duration.ofSeconds(5), + newSessionThreadPoolSize, + new DefaultSlotMatcher(), + Duration.ofSeconds(30)); + + // Create multiple nodes, some with free slots, some fully occupied + List<Node> nodes = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + URI nodeUri = new URI("http://example:" + (1234 + i)); + Node node = + LocalNode.builder(tracer, bus, nodeUri, nodeUri, registrationSecret) + .add( + new ImmutableCapabilities("browserName", "cheese"), + new TestSessionFactory( + (id, c) -> + new Session(id, nodeUri, new ImmutableCapabilities(), c, Instant.now()))) + .build(); + nodes.add(node); + distributor.add(node); + } + + // Occupy slots on first 3 nodes + for (int i = 0; i < 3; i++) { + SessionRequest sessionRequest = + new SessionRequest( + new RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "cheese")), + Map.of(), + Map.of()); + distributor.newSession(sessionRequest); + } + + // getAvailableNodes should only return the 2 
nodes with free slots + Set<NodeStatus> availableFreeNodes = distributor.getAvailableNodes(); + assertThat(availableFreeNodes).hasSize(2); + + // Verify all returned nodes have free slots + for (NodeStatus nodeStatus : availableFreeNodes) { + boolean hasFreeSlot = + nodeStatus.getSlots().stream().anyMatch(slot -> slot.getSession() == null); + assertThat(hasFreeSlot).isTrue(); + assertThat(nodeStatus.getAvailability()).isEqualTo(UP); + } + } + + @Test + void shouldHandleAllNodesFullyOccupied() throws URISyntaxException { + // Create a distributor + NewSessionQueue queue = + new LocalNewSessionQueue( + tracer, + new DefaultSlotMatcher(), + Duration.ofSeconds(2), + Duration.ofSeconds(2), + Duration.ofSeconds(1), + registrationSecret, + 5); + LocalDistributor distributor = + new LocalDistributor( + tracer, + bus, + new PassthroughHttpClient.Factory(localNode), + new LocalSessionMap(tracer, bus), + queue, + new DefaultSlotSelector(), + registrationSecret, + Duration.ofMinutes(5), + false, + Duration.ofSeconds(5), + newSessionThreadPoolSize, + new DefaultSlotMatcher(), + Duration.ofSeconds(30)); + + // Create nodes with single slot each + List<Node> nodes = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + URI nodeUri = new URI("http://example:" + (1234 + i)); + Node node = + LocalNode.builder(tracer, bus, nodeUri, nodeUri, registrationSecret) + .add( + new ImmutableCapabilities("browserName", "cheese"), + new TestSessionFactory( + (id, c) -> + new Session(id, nodeUri, new ImmutableCapabilities(), c, Instant.now()))) + .build(); + nodes.add(node); + distributor.add(node); + } + + // Occupy all slots + for (int i = 0; i < 3; i++) { + SessionRequest sessionRequest = + new SessionRequest( + new RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "cheese")), + Map.of(), + Map.of()); + distributor.newSession(sessionRequest); + } + + // getAvailableNodes should return empty set when all nodes are fully occupied + 
Set<NodeStatus> availableFreeNodes = distributor.getAvailableNodes(); + assertThat(availableFreeNodes).isEmpty(); + } + private class Handler extends Session implements HttpHandler { private Handler(Capabilities capabilities) {
diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 27656be..9cd2913 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py
@@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import threading +from dataclasses import dataclass from typing import Any, Callable, Optional, Union from selenium.webdriver.common.bidi.common import command_builder @@ -373,55 +375,298 @@ ) -class BrowsingContextEvent: - """Base class for browsing context events.""" +class ContextCreated: + """Event class for browsingContext.contextCreated event.""" - def __init__(self, event_class: str, **kwargs): - self.event_class = event_class - self.params = kwargs + event_class = "browsingContext.contextCreated" @classmethod - def from_json(cls, json: dict) -> "BrowsingContextEvent": - """Creates a BrowsingContextEvent instance from a dictionary. + def from_json(cls, json: dict): + if isinstance(json, BrowsingContextInfo): + return json + return BrowsingContextInfo.from_json(json) + + +class ContextDestroyed: + """Event class for browsingContext.contextDestroyed event.""" + + event_class = "browsingContext.contextDestroyed" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, BrowsingContextInfo): + return json + return BrowsingContextInfo.from_json(json) + + +class NavigationStarted: + """Event class for browsingContext.navigationStarted event.""" + + event_class = "browsingContext.navigationStarted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationCommitted: + """Event class for browsingContext.navigationCommitted event.""" + + event_class = "browsingContext.navigationCommitted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationFailed: + """Event class for browsingContext.navigationFailed event.""" + + event_class = "browsingContext.navigationFailed" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, 
NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationAborted: + """Event class for browsingContext.navigationAborted event.""" + + event_class = "browsingContext.navigationAborted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class DomContentLoaded: + """Event class for browsingContext.domContentLoaded event.""" + + event_class = "browsingContext.domContentLoaded" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class Load: + """Event class for browsingContext.load event.""" + + event_class = "browsingContext.load" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class FragmentNavigated: + """Event class for browsingContext.fragmentNavigated event.""" + + event_class = "browsingContext.fragmentNavigated" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class DownloadWillBegin: + """Event class for browsingContext.downloadWillBegin event.""" + + event_class = "browsingContext.downloadWillBegin" + + @classmethod + def from_json(cls, json: dict): + return DownloadWillBeginParams.from_json(json) + + +class UserPromptOpened: + """Event class for browsingContext.userPromptOpened event.""" + + event_class = "browsingContext.userPromptOpened" + + @classmethod + def from_json(cls, json: dict): + return UserPromptOpenedParams.from_json(json) + + +class UserPromptClosed: + """Event class for browsingContext.userPromptClosed event.""" + + event_class = "browsingContext.userPromptClosed" + + @classmethod + def from_json(cls, json: dict): + return UserPromptClosedParams.from_json(json) + + +class HistoryUpdated: + """Event class for 
browsingContext.historyUpdated event.""" + + event_class = "browsingContext.historyUpdated" + + @classmethod + def from_json(cls, json: dict): + return HistoryUpdatedParams.from_json(json) + + +@dataclass +class EventConfig: + event_key: str + bidi_event: str + event_class: type + + +class _EventManager: + """Class to manage event subscriptions and callbacks for BrowsingContext.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + # Thread safety lock for subscription operations + self._subscription_lock = threading.Lock() + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: Optional[list[str]] = None) -> None: + """Subscribe to a BiDi event if not already subscribed. Parameters: - ----------- - json: A dictionary containing the event information. - - Returns: - ------- - BrowsingContextEvent: A new instance of BrowsingContextEvent. + ---------- + bidi_event: The BiDi event name. + contexts: Optional browsing context IDs to subscribe to. 
""" - event_class = json.get("event_class") - if event_class is None or not isinstance(event_class, str): - raise ValueError("event_class is required and must be a string") + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts)) + self.subscriptions[bidi_event] = [] - return cls(event_class=event_class, **json) + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist. + + Parameters: + ---------- + bidi_event: The BiDi event name. + """ + with self._subscription_lock: + callback_list = self.subscriptions.get(bidi_event) + if callback_list is not None and not callback_list: + session = Session(self.conn) + self.conn.execute(session.unsubscribe(bidi_event)) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + callback_list = self.subscriptions.get(bidi_event) + if callback_list and callback_id in callback_list: + callback_list.remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: + event_config = self.validate_event(event) + + callback_id = self.conn.add_callback(event_config.event_class, callback) + + # Subscribe to the event if needed + self.subscribe_to_event(event_config.bidi_event, contexts) + + # Track the callback + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + + # Remove the callback from the connection + self.conn.remove_callback(event_config.event_class, 
callback_id) + + # Remove from tracking collections + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + + # Unsubscribe if no more callbacks exist + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers from the browsing context.""" + with self._subscription_lock: + if not self.subscriptions: + return + + session = Session(self.conn) + + for bidi_event, callback_ids in list(self.subscriptions.items()): + event_class = self._bidi_to_class.get(bidi_event) + if event_class: + # Remove all callbacks for this event + for callback_id in callback_ids: + self.conn.remove_callback(event_class, callback_id) + + self.conn.execute(session.unsubscribe(bidi_event)) + + self.subscriptions.clear() class BrowsingContext: """BiDi implementation of the browsingContext module.""" - EVENTS = { - "context_created": "browsingContext.contextCreated", - "context_destroyed": "browsingContext.contextDestroyed", - "dom_content_loaded": "browsingContext.domContentLoaded", - "download_will_begin": "browsingContext.downloadWillBegin", - "fragment_navigated": "browsingContext.fragmentNavigated", - "history_updated": "browsingContext.historyUpdated", - "load": "browsingContext.load", - "navigation_aborted": "browsingContext.navigationAborted", - "navigation_committed": "browsingContext.navigationCommitted", - "navigation_failed": "browsingContext.navigationFailed", - "navigation_started": "browsingContext.navigationStarted", - "user_prompt_closed": "browsingContext.userPromptClosed", - "user_prompt_opened": "browsingContext.userPromptOpened", + EVENT_CONFIGS = { + "context_created": EventConfig("context_created", "browsingContext.contextCreated", ContextCreated), + "context_destroyed": EventConfig("context_destroyed", "browsingContext.contextDestroyed", ContextDestroyed), + "dom_content_loaded": EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", DomContentLoaded), + 
"download_will_begin": EventConfig( + "download_will_begin", "browsingContext.downloadWillBegin", DownloadWillBegin + ), + "fragment_navigated": EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", FragmentNavigated), + "history_updated": EventConfig("history_updated", "browsingContext.historyUpdated", HistoryUpdated), + "load": EventConfig("load", "browsingContext.load", Load), + "navigation_aborted": EventConfig("navigation_aborted", "browsingContext.navigationAborted", NavigationAborted), + "navigation_committed": EventConfig( + "navigation_committed", "browsingContext.navigationCommitted", NavigationCommitted + ), + "navigation_failed": EventConfig("navigation_failed", "browsingContext.navigationFailed", NavigationFailed), + "navigation_started": EventConfig("navigation_started", "browsingContext.navigationStarted", NavigationStarted), + "user_prompt_closed": EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", UserPromptClosed), + "user_prompt_opened": EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", UserPromptOpened), } def __init__(self, conn): self.conn = conn - self.subscriptions = {} - self.callbacks = {} + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + @classmethod + def get_event_names(cls) -> list[str]: + """Get a list of all available event names. + + Returns: + ------- + List[str]: A list of event names that can be used with event handlers. + """ + return list(cls.EVENT_CONFIGS.keys()) def activate(self, context: str) -> None: """Activates and focuses the given top-level traversable. @@ -739,50 +984,6 @@ result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) return result - def _on_event(self, event_name: str, callback: Callable) -> int: - """Set a callback function to subscribe to a browsing context event. - - Parameters: - ---------- - event_name: The event to subscribe to. - callback: The callback function to execute on event. 
- - Returns: - ------- - int: callback id - """ - event = BrowsingContextEvent(event_name) - - def _callback(event_data): - if event_name == self.EVENTS["context_created"] or event_name == self.EVENTS["context_destroyed"]: - info = BrowsingContextInfo.from_json(event_data.params) - callback(info) - elif event_name == self.EVENTS["download_will_begin"]: - params = DownloadWillBeginParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["user_prompt_opened"]: - params = UserPromptOpenedParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["user_prompt_closed"]: - params = UserPromptClosedParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["history_updated"]: - params = HistoryUpdatedParams.from_json(event_data.params) - callback(params) - else: - # For navigation events - info = NavigationInfo.from_json(event_data.params) - callback(info) - - callback_id = self.conn.add_callback(event, _callback) - - if event_name in self.callbacks: - self.callbacks[event_name].append(callback_id) - else: - self.callbacks[event_name] = [callback_id] - - return callback_id - def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: """Add an event handler to the browsing context. 
@@ -796,24 +997,7 @@ ------- int: callback id """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") - - callback_id = self._on_event(event_name, callback) - - if event_name in self.subscriptions: - self.subscriptions[event_name].append(callback_id) - else: - params = {"events": [event_name]} - if contexts is not None: - params["browsingContexts"] = contexts - session = Session(self.conn) - self.conn.execute(session.subscribe(**params)) - self.subscriptions[event_name] = [callback_id] - - return callback_id + return self._event_manager.add_event_handler(event, callback, contexts) def remove_event_handler(self, event: str, callback_id: int) -> None: """Remove an event handler from the browsing context. @@ -823,31 +1007,8 @@ event: The event to unsubscribe from. callback_id: The callback id to remove. """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") - - event_obj = BrowsingContextEvent(event_name) - - self.conn.remove_callback(event_obj, callback_id) - if event_name in self.subscriptions: - callbacks = self.subscriptions[event_name] - if callback_id in callbacks: - callbacks.remove(callback_id) - if not callbacks: - params = {"events": [event_name]} - session = Session(self.conn) - self.conn.execute(session.unsubscribe(**params)) - del self.subscriptions[event_name] + self._event_manager.remove_event_handler(event, callback_id) def clear_event_handlers(self) -> None: """Clear all event handlers from the browsing context.""" - for event_name in self.subscriptions: - event = BrowsingContextEvent(event_name) - for callback_id in self.subscriptions[event_name]: - self.conn.remove_callback(event, callback_id) - params = {"events": [event_name]} - session = Session(self.conn) - self.conn.execute(session.unsubscribe(**params)) - self.subscriptions = {} + self._event_manager.clear_event_handlers()
diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py new file mode 100644 index 0000000..8b2f40e --- /dev/null +++ b/py/selenium/webdriver/common/bidi/input.py
@@ -0,0 +1,474 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import math +from dataclasses import dataclass, field +from typing import Any, Optional, Union + +from selenium.webdriver.common.bidi.common import command_builder +from selenium.webdriver.common.bidi.session import Session + + +class PointerType: + """Represents the possible pointer types.""" + + MOUSE = "mouse" + PEN = "pen" + TOUCH = "touch" + + VALID_TYPES = {MOUSE, PEN, TOUCH} + + +class Origin: + """Represents the possible origin types.""" + + VIEWPORT = "viewport" + POINTER = "pointer" + + +@dataclass +class ElementOrigin: + """Represents an element origin for input actions.""" + + type: str + element: dict + + def __init__(self, element_reference: dict): + self.type = "element" + self.element = element_reference + + def to_dict(self) -> dict: + """Convert the ElementOrigin to a dictionary.""" + return {"type": self.type, "element": self.element} + + +@dataclass +class PointerParameters: + """Represents pointer parameters for pointer actions.""" + + pointer_type: str = PointerType.MOUSE + + def __post_init__(self): + if self.pointer_type not in PointerType.VALID_TYPES: + raise ValueError(f"Invalid pointer type: {self.pointer_type}. 
Must be one of {PointerType.VALID_TYPES}") + + def to_dict(self) -> dict: + """Convert the PointerParameters to a dictionary.""" + return {"pointerType": self.pointer_type} + + +@dataclass +class PointerCommonProperties: + """Common properties for pointer actions.""" + + width: int = 1 + height: int = 1 + pressure: float = 0.0 + tangential_pressure: float = 0.0 + twist: int = 0 + altitude_angle: float = 0.0 + azimuth_angle: float = 0.0 + + def __post_init__(self): + if self.width < 1: + raise ValueError("width must be at least 1") + if self.height < 1: + raise ValueError("height must be at least 1") + if not (0.0 <= self.pressure <= 1.0): + raise ValueError("pressure must be between 0.0 and 1.0") + if not (0.0 <= self.tangential_pressure <= 1.0): + raise ValueError("tangential_pressure must be between 0.0 and 1.0") + if not (0 <= self.twist <= 359): + raise ValueError("twist must be between 0 and 359") + if not (0.0 <= self.altitude_angle <= math.pi / 2): + raise ValueError("altitude_angle must be between 0.0 and π/2") + if not (0.0 <= self.azimuth_angle <= 2 * math.pi): + raise ValueError("azimuth_angle must be between 0.0 and 2π") + + def to_dict(self) -> dict: + """Convert the PointerCommonProperties to a dictionary.""" + result: dict[str, Any] = {} + if self.width != 1: + result["width"] = self.width + if self.height != 1: + result["height"] = self.height + if self.pressure != 0.0: + result["pressure"] = self.pressure + if self.tangential_pressure != 0.0: + result["tangentialPressure"] = self.tangential_pressure + if self.twist != 0: + result["twist"] = self.twist + if self.altitude_angle != 0.0: + result["altitudeAngle"] = self.altitude_angle + if self.azimuth_angle != 0.0: + result["azimuthAngle"] = self.azimuth_angle + return result + + +# Action classes +@dataclass +class PauseAction: + """Represents a pause action.""" + + duration: Optional[int] = None + + @property + def type(self) -> str: + return "pause" + + def to_dict(self) -> dict: + """Convert the 
PauseAction to a dictionary.""" + result: dict[str, Any] = {"type": self.type} + if self.duration is not None: + result["duration"] = self.duration + return result + + +@dataclass +class KeyDownAction: + """Represents a key down action.""" + + value: str = "" + + @property + def type(self) -> str: + return "keyDown" + + def to_dict(self) -> dict: + """Convert the KeyDownAction to a dictionary.""" + return {"type": self.type, "value": self.value} + + +@dataclass +class KeyUpAction: + """Represents a key up action.""" + + value: str = "" + + @property + def type(self) -> str: + return "keyUp" + + def to_dict(self) -> dict: + """Convert the KeyUpAction to a dictionary.""" + return {"type": self.type, "value": self.value} + + +@dataclass +class PointerDownAction: + """Represents a pointer down action.""" + + button: int = 0 + properties: Optional[PointerCommonProperties] = None + + @property + def type(self) -> str: + return "pointerDown" + + def to_dict(self) -> dict: + """Convert the PointerDownAction to a dictionary.""" + result: dict[str, Any] = {"type": self.type, "button": self.button} + if self.properties: + result.update(self.properties.to_dict()) + return result + + +@dataclass +class PointerUpAction: + """Represents a pointer up action.""" + + button: int = 0 + + @property + def type(self) -> str: + return "pointerUp" + + def to_dict(self) -> dict: + """Convert the PointerUpAction to a dictionary.""" + return {"type": self.type, "button": self.button} + + +@dataclass +class PointerMoveAction: + """Represents a pointer move action.""" + + x: float = 0 + y: float = 0 + duration: Optional[int] = None + origin: Optional[Union[str, ElementOrigin]] = None + properties: Optional[PointerCommonProperties] = None + + @property + def type(self) -> str: + return "pointerMove" + + def to_dict(self) -> dict: + """Convert the PointerMoveAction to a dictionary.""" + result: dict[str, Any] = {"type": self.type, "x": self.x, "y": self.y} + if self.duration is not None: + 
result["duration"] = self.duration + if self.origin is not None: + if isinstance(self.origin, ElementOrigin): + result["origin"] = self.origin.to_dict() + else: + result["origin"] = self.origin + if self.properties: + result.update(self.properties.to_dict()) + return result + + +@dataclass +class WheelScrollAction: + """Represents a wheel scroll action.""" + + x: int = 0 + y: int = 0 + delta_x: int = 0 + delta_y: int = 0 + duration: Optional[int] = None + origin: Optional[Union[str, ElementOrigin]] = Origin.VIEWPORT + + @property + def type(self) -> str: + return "scroll" + + def to_dict(self) -> dict: + """Convert the WheelScrollAction to a dictionary.""" + result: dict[str, Any] = { + "type": self.type, + "x": self.x, + "y": self.y, + "deltaX": self.delta_x, + "deltaY": self.delta_y, + } + if self.duration is not None: + result["duration"] = self.duration + if self.origin is not None: + if isinstance(self.origin, ElementOrigin): + result["origin"] = self.origin.to_dict() + else: + result["origin"] = self.origin + return result + + +# Source Actions +@dataclass +class NoneSourceActions: + """Represents a sequence of none actions.""" + + id: str = "" + actions: list[PauseAction] = field(default_factory=list) + + @property + def type(self) -> str: + return "none" + + def to_dict(self) -> dict: + """Convert the NoneSourceActions to a dictionary.""" + return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + + +@dataclass +class KeySourceActions: + """Represents a sequence of key actions.""" + + id: str = "" + actions: list[Union[PauseAction, KeyDownAction, KeyUpAction]] = field(default_factory=list) + + @property + def type(self) -> str: + return "key" + + def to_dict(self) -> dict: + """Convert the KeySourceActions to a dictionary.""" + return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + + +@dataclass +class PointerSourceActions: + """Represents a sequence of pointer 
actions.""" + + id: str = "" + parameters: Optional[PointerParameters] = None + actions: list[Union[PauseAction, PointerDownAction, PointerUpAction, PointerMoveAction]] = field( + default_factory=list + ) + + def __post_init__(self): + if self.parameters is None: + self.parameters = PointerParameters() + + @property + def type(self) -> str: + return "pointer" + + def to_dict(self) -> dict: + """Convert the PointerSourceActions to a dictionary.""" + result: dict[str, Any] = { + "type": self.type, + "id": self.id, + "actions": [action.to_dict() for action in self.actions], + } + if self.parameters: + result["parameters"] = self.parameters.to_dict() + return result + + +@dataclass +class WheelSourceActions: + """Represents a sequence of wheel actions.""" + + id: str = "" + actions: list[Union[PauseAction, WheelScrollAction]] = field(default_factory=list) + + @property + def type(self) -> str: + return "wheel" + + def to_dict(self) -> dict: + """Convert the WheelSourceActions to a dictionary.""" + return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + + +@dataclass +class FileDialogInfo: + """Represents file dialog information from input.fileDialogOpened event.""" + + context: str + multiple: bool + element: Optional[dict] = None + + @classmethod + def from_dict(cls, data: dict) -> "FileDialogInfo": + """Creates a FileDialogInfo instance from a dictionary. + + Parameters: + ----------- + data: A dictionary containing the file dialog information. + + Returns: + ------- + FileDialogInfo: A new instance of FileDialogInfo. 
+ """ + return cls(context=data["context"], multiple=data["multiple"], element=data.get("element")) + + +# Event Class +class FileDialogOpened: + """Event class for input.fileDialogOpened event.""" + + event_class = "input.fileDialogOpened" + + @classmethod + def from_json(cls, json): + """Create FileDialogInfo from JSON data.""" + return FileDialogInfo.from_dict(json) + + +class Input: + """ + BiDi implementation of the input module. + """ + + def __init__(self, conn): + self.conn = conn + self.subscriptions = {} + self.callbacks = {} + + def perform_actions( + self, + context: str, + actions: list[Union[NoneSourceActions, KeySourceActions, PointerSourceActions, WheelSourceActions]], + ) -> None: + """Performs a sequence of user input actions. + + Parameters: + ----------- + context: The browsing context ID where actions should be performed. + actions: A list of source actions to perform. + """ + params = {"context": context, "actions": [action.to_dict() for action in actions]} + self.conn.execute(command_builder("input.performActions", params)) + + def release_actions(self, context: str) -> None: + """Releases all input state for the given context. + + Parameters: + ----------- + context: The browsing context ID to release actions for. + """ + params = {"context": context} + self.conn.execute(command_builder("input.releaseActions", params)) + + def set_files(self, context: str, element: dict, files: list[str]) -> None: + """Sets files for a file input element. + + Parameters: + ----------- + context: The browsing context ID. + element: The element reference (script.SharedReference). + files: A list of file paths to set. + """ + params = {"context": context, "element": element, "files": files} + self.conn.execute(command_builder("input.setFiles", params)) + + def add_file_dialog_handler(self, handler): + """Add a handler for file dialog opened events. + + Parameters: + ----------- + handler: Callback function that takes a FileDialogInfo object. 
+ + Returns: + -------- + int: Callback ID for removing the handler later. + """ + # Subscribe to the event if not already subscribed + if FileDialogOpened.event_class not in self.subscriptions: + session = Session(self.conn) + self.conn.execute(session.subscribe(FileDialogOpened.event_class)) + self.subscriptions[FileDialogOpened.event_class] = [] + + # Add callback - the callback receives the parsed FileDialogInfo directly + callback_id = self.conn.add_callback(FileDialogOpened, handler) + + self.subscriptions[FileDialogOpened.event_class].append(callback_id) + self.callbacks[callback_id] = handler + + return callback_id + + def remove_file_dialog_handler(self, callback_id: int) -> None: + """Remove a file dialog handler. + + Parameters: + ----------- + callback_id: The callback ID returned by add_file_dialog_handler. + """ + if callback_id in self.callbacks: + del self.callbacks[callback_id] + + if FileDialogOpened.event_class in self.subscriptions: + if callback_id in self.subscriptions[FileDialogOpened.event_class]: + self.subscriptions[FileDialogOpened.event_class].remove(callback_id) + + # If no more callbacks for this event, unsubscribe + if not self.subscriptions[FileDialogOpened.event_class]: + session = Session(self.conn) + self.conn.execute(session.unsubscribe(FileDialogOpened.event_class)) + del self.subscriptions[FileDialogOpened.event_class] + + self.conn.remove_callback(FileDialogOpened, callback_id)
diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 74b8a35..50e93e1 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py
@@ -319,7 +319,7 @@ ) if result.type == "success": - return result.result + return result.result if result.result is not None else {} else: error_message = "Error while executing script" if result.exception_details:
diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index a1467fd..fa6db4b 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py
@@ -41,6 +41,7 @@ from selenium.webdriver.common.bidi.browser import Browser from selenium.webdriver.common.bidi.browsing_context import BrowsingContext from selenium.webdriver.common.bidi.emulation import Emulation +from selenium.webdriver.common.bidi.input import Input from selenium.webdriver.common.bidi.network import Network from selenium.webdriver.common.bidi.permissions import Permissions from selenium.webdriver.common.bidi.script import Script @@ -272,6 +273,7 @@ self._webextension = None self._permissions = None self._emulation = None + self._input = None self._devtools = None def __repr__(self): @@ -1420,6 +1422,29 @@ return self._emulation + @property + def input(self): + """Returns an input module object for BiDi input commands. + + Returns: + -------- + Input: an object containing access to BiDi input commands. + + Examples: + --------- + >>> from selenium.webdriver.common.bidi.input import KeySourceActions, KeyDownAction, KeyUpAction + >>> key_actions = KeySourceActions(id="keyboard", actions=[KeyDownAction(value="a"), KeyUpAction(value="a")]) + >>> driver.input.perform_actions(driver.current_window_handle, [key_actions]) + >>> driver.input.release_actions(driver.current_window_handle) + """ + if not self._websocket_connection: + self._start_bidi() + + if self._input is None: + self._input = Input(self._websocket_connection) + + return self._input + def _get_cdp_details(self): import json
diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 74a0f53..768640d 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py
@@ -16,6 +16,9 @@ # under the License. import base64 +import concurrent.futures +import threading +import time import pytest @@ -525,3 +528,553 @@ ) # The login form should have 3 input elements (email, age, and submit button) assert len(elements) == 3 + + +# Tests for event handlers + + +def test_add_event_handler_context_created(driver): + """Test adding event handler for context_created event.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + assert callback_id is not None + + # Create a new context to trigger the event + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify the event was received (might be > 1 since default context is also included) + assert len(events_received) >= 1 + assert events_received[0].context == context_id or events_received[1].context == context_id + + driver.browsing_context.close(context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id) + + +def test_add_event_handler_context_destroyed(driver): + """Test adding event handler for context_destroyed event.""" + events_received = [] + + def on_context_destroyed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_destroyed", on_context_destroyed) + assert callback_id is not None + + # Create and then close a context to trigger the event + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(context_id) + + assert len(events_received) == 1 + assert events_received[0].context == context_id + + driver.browsing_context.remove_event_handler("context_destroyed", callback_id) + + +def test_add_event_handler_navigation_committed(driver, pages): + """Test adding event handler for navigation_committed event.""" + events_received = [] + + def on_navigation_committed(info): + events_received.append(info) 
+ + callback_id = driver.browsing_context.add_event_handler("navigation_committed", on_navigation_committed) + assert callback_id is not None + + # Navigate to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) >= 1 + assert any(url in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("navigation_committed", callback_id) + + +def test_add_event_handler_dom_content_loaded(driver, pages): + """Test adding event handler for dom_content_loaded event.""" + events_received = [] + + def on_dom_content_loaded(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("dom_content_loaded", on_dom_content_loaded) + assert callback_id is not None + + # Navigate to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("dom_content_loaded", callback_id) + + +def test_add_event_handler_load(driver, pages): + """Test adding event handler for load event.""" + events_received = [] + + def on_load(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("load", on_load) + assert callback_id is not None + + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("load", callback_id) + + +def test_add_event_handler_navigation_started(driver, 
pages): + """Test adding event handler for navigation_started event.""" + events_received = [] + + def on_navigation_started(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("navigation_started", on_navigation_started) + assert callback_id is not None + + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("navigation_started", callback_id) + + +def test_add_event_handler_fragment_navigated(driver, pages): + """Test adding event handler for fragment_navigated event.""" + events_received = [] + + def on_fragment_navigated(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("fragment_navigated", on_fragment_navigated) + assert callback_id is not None + + # First navigate to a page + context_id = driver.current_window_handle + url = pages.url("linked_image.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + # Then navigate to the same page with a fragment to trigger the event + fragment_url = url + "#link" + driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("link" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("fragment_navigated", callback_id) + + [email protected]_firefox +def test_add_event_handler_navigation_failed(driver): + """Test adding event handler for navigation_failed event.""" + events_received = [] + + def on_navigation_failed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("navigation_failed", on_navigation_failed) + assert callback_id is not 
None + + # Navigate to an invalid URL to trigger the event + context_id = driver.current_window_handle + try: + driver.browsing_context.navigate(context=context_id, url="http://invalid-domain-that-does-not-exist.test/") + except Exception: + # Expect an exception due to navigation failure + pass + + assert len(events_received) == 1 + assert events_received[0].url == "http://invalid-domain-that-does-not-exist.test/" + assert events_received[0].context == context_id + + driver.browsing_context.remove_event_handler("navigation_failed", callback_id) + + +def test_add_event_handler_user_prompt_opened(driver, pages): + """Test adding event handler for user_prompt_opened event.""" + events_received = [] + + def on_user_prompt_opened(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("user_prompt_opened", on_user_prompt_opened) + assert callback_id is not None + + # Create an alert to trigger the event + create_alert_page(driver, pages) + driver.find_element(By.ID, "alert").click() + WebDriverWait(driver, 5).until(EC.alert_is_present()) + + assert len(events_received) == 1 + assert events_received[0].type == "alert" + assert events_received[0].message == "cheese" + + # Clean up the alert + driver.browsing_context.handle_user_prompt(context=driver.current_window_handle) + driver.browsing_context.remove_event_handler("user_prompt_opened", callback_id) + + +def test_add_event_handler_user_prompt_closed(driver, pages): + """Test adding event handler for user_prompt_closed event.""" + events_received = [] + + def on_user_prompt_closed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("user_prompt_closed", on_user_prompt_closed) + assert callback_id is not None + + create_prompt_page(driver, pages) + driver.execute_script("prompt('Enter something')") + WebDriverWait(driver, 5).until(EC.alert_is_present()) + + driver.browsing_context.handle_user_prompt( + 
context=driver.current_window_handle, accept=True, user_text="test input" + ) + + assert len(events_received) == 1 + assert events_received[0].accepted is True + assert events_received[0].user_text == "test input" + + driver.browsing_context.remove_event_handler("user_prompt_closed", callback_id) + + [email protected]_chrome [email protected]_firefox [email protected]_edge +def test_add_event_handler_history_updated(driver, pages): + """Test adding event handler for history_updated event.""" + events_received = [] + + def on_history_updated(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("history_updated", on_history_updated) + assert callback_id is not None + + # Navigate to a page and use history API to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + # Use history.pushState to trigger history updated event + driver.execute_script("history.pushState({}, '', '/new-path');") + + assert len(events_received) == 1 + assert any("/new-path" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("history_updated", callback_id) + + [email protected]_firefox +def test_add_event_handler_download_will_begin(driver, pages): + """Test adding event handler for download_will_begin event.""" + events_received = [] + + def on_download_will_begin(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("download_will_begin", on_download_will_begin) + assert callback_id is not None + + # click on a download link to trigger the event + context_id = driver.current_window_handle + url = pages.url("downloads/download.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + download_xpath_file_1_txt = '//*[@id="file-1"]' + driver.find_element(By.XPATH, 
download_xpath_file_1_txt).click() + WebDriverWait(driver, 5).until(lambda d: len(events_received) > 0) + + assert len(events_received) == 1 + assert events_received[0].suggested_filename == "file_1.txt" + + driver.browsing_context.remove_event_handler("download_will_begin", callback_id) + + +def test_add_event_handler_with_specific_contexts(driver): + """Test adding event handler with specific browsing contexts.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Add event handler for specific context + callback_id = driver.browsing_context.add_event_handler( + "context_created", on_context_created, contexts=[context_id] + ) + assert callback_id is not None + + # Create another context (should trigger event) + new_context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + assert len(events_received) >= 1 + + driver.browsing_context.close(context_id) + driver.browsing_context.close(new_context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id) + + +def test_remove_event_handler(driver): + """Test removing event handler.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + + # Create a context to trigger the event + context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) + + initial_events = len(events_received) + + # Remove the event handler + driver.browsing_context.remove_event_handler("context_created", callback_id) + + # Create another context (should not trigger event after removal) + context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify no new events were received after removal + assert len(events_received) == initial_events + + driver.browsing_context.close(context_id_1) + driver.browsing_context.close(context_id_2) + + +def 
test_multiple_event_handlers_same_event(driver): + """Test adding multiple event handlers for the same event.""" + events_received_1 = [] + events_received_2 = [] + + def on_context_created_1(info): + events_received_1.append(info) + + def on_context_created_2(info): + events_received_2.append(info) + + # Add multiple event handlers for the same event + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + + # Create a context to trigger both handlers + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify both handlers received the event + assert len(events_received_1) >= 1 + assert len(events_received_2) >= 1 + # Check any of the events has the required context ID + assert any(event.context == context_id for event in events_received_1) + assert any(event.context == context_id for event in events_received_2) + + driver.browsing_context.close(context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id_1) + driver.browsing_context.remove_event_handler("context_created", callback_id_2) + + +def test_remove_specific_event_handler_multiple_handlers(driver): + """Test removing a specific event handler when multiple handlers exist.""" + events_received_1 = [] + events_received_2 = [] + + def on_context_created_1(info): + events_received_1.append(info) + + def on_context_created_2(info): + events_received_2.append(info) + + # Add multiple event handlers + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + + # Create a context to trigger both handlers + context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify both handlers received the event + assert len(events_received_1) >= 1 + assert 
len(events_received_2) >= 1 + + # store the initial event counts + initial_count_1 = len(events_received_1) + initial_count_2 = len(events_received_2) + + # Remove only the first handler + driver.browsing_context.remove_event_handler("context_created", callback_id_1) + + # Create another context + context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify only the second handler received the new event + assert len(events_received_1) == initial_count_1 # No new events + assert len(events_received_2) == initial_count_2 + 1 # 1 new event + + driver.browsing_context.close(context_id_1) + driver.browsing_context.close(context_id_2) + driver.browsing_context.remove_event_handler("context_created", callback_id_2) + + +class _EventHandlerTestHelper: + def __init__(self, driver): + self.driver = driver + self.events_received = [] + self.context_counts = {} + self.event_type_counts = {} + self.processing_times = [] + self.consistency_errors = [] + self.thread_errors = [] + self.callback_ids = [] + self.data_lock = threading.Lock() + self.registration_complete = threading.Event() + + def make_callback(self): + def callback(info): + start_time = time.time() + time.sleep(0.02) # Simulate race window + + with self.data_lock: + initial_event_count = len(self.events_received) + + self.events_received.append(info) + + context_id = info.context + self.context_counts.setdefault(context_id, 0) + self.context_counts[context_id] += 1 + + event_type = info.__class__.__name__ + self.event_type_counts.setdefault(event_type, 0) + self.event_type_counts[event_type] += 1 + + processing_time = time.time() - start_time + self.processing_times.append(processing_time) + + final_event_count = len(self.events_received) + final_context_total = sum(self.context_counts.values()) + final_type_total = sum(self.event_type_counts.values()) + final_processing_count = len(self.processing_times) + + expected_count = initial_event_count + 1 + if not ( + final_event_count + == 
final_context_total + == final_type_total + == final_processing_count + == expected_count + ): + self.consistency_errors.append("Data consistency error") + + return callback + + def register_handler(self, thread_id): + try: + callback = self.make_callback() + callback_id = self.driver.browsing_context.add_event_handler("context_created", callback) + with self.data_lock: + self.callback_ids.append(callback_id) + if len(self.callback_ids) == 5: + self.registration_complete.set() + return callback_id + except Exception as e: + with self.data_lock: + self.thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") + return None + + def remove_handler(self, callback_id, thread_id): + try: + self.driver.browsing_context.remove_event_handler("context_created", callback_id) + except Exception as e: + with self.data_lock: + self.thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") + + +def test_concurrent_event_handler_registration(driver): + helper = _EventHandlerTestHelper(driver) + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(helper.register_handler, f"reg-{i}") for i in range(5)] + for future in futures: + future.result(timeout=15) + + helper.registration_complete.wait(timeout=5) + assert len(helper.callback_ids) == 5, f"Expected 5 handlers, got {len(helper.callback_ids)}" + assert not helper.thread_errors, "Errors during registration: \n" + "\n".join(helper.thread_errors) + + +def test_event_callback_data_consistency(driver): + helper = _EventHandlerTestHelper(driver) + + for i in range(5): + helper.register_handler(f"reg-{i}") + + test_contexts = [] + for _ in range(3): + context = driver.browsing_context.create(type=WindowTypes.TAB) + test_contexts.append(context) + + for ctx in test_contexts: + driver.browsing_context.close(ctx) + + with helper.data_lock: + assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors) + assert len(helper.events_received) > 
0, "No events received" + assert len(helper.events_received) == sum(helper.context_counts.values()) + assert len(helper.events_received) == sum(helper.event_type_counts.values()) + assert len(helper.events_received) == len(helper.processing_times) + + +def test_concurrent_event_handler_removal(driver): + helper = _EventHandlerTestHelper(driver) + + for i in range(5): + helper.register_handler(f"reg-{i}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit(helper.remove_handler, callback_id, f"rem-{i}") + for i, callback_id in enumerate(helper.callback_ids) + ] + for future in futures: + future.result(timeout=15) + + assert not helper.thread_errors, "Errors during removal: \n" + "\n".join(helper.thread_errors) + + +def test_no_event_after_handler_removal(driver): + helper = _EventHandlerTestHelper(driver) + + for i in range(5): + helper.register_handler(f"reg-{i}") + + context = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(context) + + events_before = len(helper.events_received) + + for i, callback_id in enumerate(helper.callback_ids): + helper.remove_handler(callback_id, f"rem-{i}") + + post_context = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(post_context) + + with helper.data_lock: + new_events = len(helper.events_received) - events_before + + assert new_events == 0, f"Expected 0 new events after removal, got {new_events}"
diff --git a/py/test/selenium/webdriver/common/bidi_input_tests.py b/py/test/selenium/webdriver/common/bidi_input_tests.py new file mode 100644 index 0000000..ecbe0bd --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_input_tests.py
@@ -0,0 +1,369 @@
+# Licensed to the Software Freedom Conservancy (SFC) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The SFC licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for the WebDriver BiDi ``input`` module exposed as ``driver.input``."""
+
+import os
+import tempfile
+import time
+
+import pytest
+
+from selenium.webdriver.common.bidi.input import (
+    ElementOrigin,
+    FileDialogInfo,
+    KeyDownAction,
+    KeySourceActions,
+    KeyUpAction,
+    Origin,
+    PauseAction,
+    PointerCommonProperties,
+    PointerDownAction,
+    PointerMoveAction,
+    PointerParameters,
+    PointerSourceActions,
+    PointerType,
+    PointerUpAction,
+    WheelScrollAction,
+    WheelSourceActions,
+)
+from selenium.webdriver.common.by import By
+from selenium.webdriver.support.ui import WebDriverWait
+
+
+def _element_center(element):
+    """Return the ``(x, y)`` centre of ``element`` computed from its location and size."""
+    location = element.location
+    size = element.size
+    return location["x"] + size["width"] // 2, location["y"] + size["height"] // 2
+
+
+def _shared_reference(element):
+    """Return the BiDi shared reference (``{"sharedId": ...}``) for a WebElement."""
+    return {"sharedId": element.id}
+
+
+def _type_actions(text):
+    """Return a KeyDown/KeyUp pair for every character of ``text``, in order."""
+    actions = []
+    for char in text:
+        actions.append(KeyDownAction(value=char))
+        actions.append(KeyUpAction(value=char))
+    return actions
+
+
+def _make_temp_text_file(content):
+    """Create a ``.txt`` temp file containing ``content``, close it, and return its path.
+
+    The file is created with ``delete=False``; the caller must remove it (see ``_remove_files``).
+    """
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as temp_file:
+        temp_file.write(content)
+    return temp_file.name
+
+
+def _remove_files(paths):
+    """Delete every path in ``paths`` that still exists."""
+    for path in paths:
+        if os.path.exists(path):
+            os.unlink(path)
+
+
+def _trigger_file_dialog(driver):
+    """Click ``#upload`` from script with user activation so the click counts as user-initiated."""
+    driver.script._evaluate(
+        expression="document.getElementById('upload').click()",
+        target={"context": driver.current_window_handle},
+        await_promise=False,
+        user_activation=True,
+    )
+
+
+def test_input_initialized(driver):
+    """Test that the input module is initialized properly."""
+    assert driver.input is not None
+
+
+def test_basic_key_input(driver, pages):
+    """Test basic keyboard input using BiDi."""
+    pages.load("single_text_input.html")
+
+    input_element = driver.find_element(By.ID, "textInput")
+
+    # Press and release each key of "hello" in turn
+    key_actions = KeySourceActions(id="keyboard", actions=_type_actions("hello"))
+
+    driver.input.perform_actions(driver.current_window_handle, [key_actions])
+
+    WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "hello")
+    assert input_element.get_attribute("value") == "hello"
+
+
+def test_key_input_with_pause(driver, pages):
+    """Test keyboard input with pause actions."""
+    pages.load("single_text_input.html")
+
+    input_element = driver.find_element(By.ID, "textInput")
+
+    # Type "a", pause for 100 ms, then type "b"
+    key_actions = KeySourceActions(
+        id="keyboard",
+        actions=[*_type_actions("a"), PauseAction(duration=100), *_type_actions("b")],
+    )
+
+    driver.input.perform_actions(driver.current_window_handle, [key_actions])
+
+    WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "ab")
+    assert input_element.get_attribute("value") == "ab"
+
+
+def test_pointer_click(driver, pages):
+    """Test basic pointer click using BiDi."""
+    pages.load("javascriptPage.html")
+
+    button = driver.find_element(By.ID, "clickField")
+    x, y = _element_center(button)
+
+    # Move to the centre of the button, then press and release the primary button
+    pointer_actions = PointerSourceActions(
+        id="mouse",
+        parameters=PointerParameters(pointer_type=PointerType.MOUSE),
+        actions=[
+            PointerMoveAction(x=x, y=y),
+            PointerDownAction(button=0),
+            PointerUpAction(button=0),
+        ],
+    )
+
+    driver.input.perform_actions(driver.current_window_handle, [pointer_actions])
+
+    WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked")
+    assert button.get_attribute("value") == "Clicked"
+
+
+def test_pointer_move_with_element_origin(driver, pages):
+    """Test pointer move with element origin."""
+    pages.load("javascriptPage.html")
+
+    button = driver.find_element(By.ID, "clickField")
+
+    # Offsets (0, 0) are interpreted relative to the element origin rather than the viewport
+    element_origin = ElementOrigin(_shared_reference(button))
+
+    pointer_actions = PointerSourceActions(
+        id="mouse",
+        parameters=PointerParameters(pointer_type=PointerType.MOUSE),
+        actions=[
+            PointerMoveAction(x=0, y=0, origin=element_origin),
+            PointerDownAction(button=0),
+            PointerUpAction(button=0),
+        ],
+    )
+
+    driver.input.perform_actions(driver.current_window_handle, [pointer_actions])
+
+    WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked")
+    assert button.get_attribute("value") == "Clicked"
+
+
+def test_pointer_with_common_properties(driver, pages):
+    """Test pointer actions with common properties."""
+    pages.load("javascriptPage.html")
+
+    button = driver.find_element(By.ID, "clickField")
+    x, y = _element_center(button)
+
+    # Non-default pointer properties, attached to the move and the press
+    properties = PointerCommonProperties(
+        width=2, height=2, pressure=0.5, tangential_pressure=0.0, twist=45, altitude_angle=0.5, azimuth_angle=1.0
+    )
+
+    pointer_actions = PointerSourceActions(
+        id="mouse",
+        parameters=PointerParameters(pointer_type=PointerType.MOUSE),
+        actions=[
+            PointerMoveAction(x=x, y=y, properties=properties),
+            PointerDownAction(button=0, properties=properties),
+            PointerUpAction(button=0),
+        ],
+    )
+
+    driver.input.perform_actions(driver.current_window_handle, [pointer_actions])
+
+    WebDriverWait(driver, 5).until(lambda d: button.get_attribute("value") == "Clicked")
+    assert button.get_attribute("value") == "Clicked"
+
+
+def test_wheel_scroll(driver, pages):
+    """Test wheel scroll actions."""
+    # page that can be scrolled
+    pages.load("scroll3.html")
+
+    # Scroll down by 100px, starting from a point inside the viewport
+    wheel_actions = WheelSourceActions(
+        id="wheel", actions=[WheelScrollAction(x=100, y=100, delta_x=0, delta_y=100, origin=Origin.VIEWPORT)]
+    )
+
+    driver.input.perform_actions(driver.current_window_handle, [wheel_actions])
+
+    # The scroll may land after perform_actions returns (e.g. smooth scrolling), so poll for it
+    WebDriverWait(driver, 5).until(lambda d: d.execute_script("return window.pageYOffset;") == 100)
+    assert driver.execute_script("return window.pageYOffset;") == 100
+
+
+def test_combined_input_actions(driver, pages):
+    """Test combining multiple input sources."""
+    pages.load("single_text_input.html")
+
+    input_element = driver.find_element(By.ID, "textInput")
+    x, y = _element_center(input_element)
+
+    # Click the centre of the input field to focus it
+    click = [
+        PointerMoveAction(x=x, y=y),
+        PointerDownAction(button=0),
+        PointerUpAction(button=0),
+    ]
+    pointer_actions = PointerSourceActions(
+        id="mouse",
+        parameters=PointerParameters(pointer_type=PointerType.MOUSE),
+        actions=click,
+    )
+
+    # Sources advance together one action per tick, so the keyboard idles for one tick per
+    # pointer action and only starts typing "test" once the click has completed
+    key_actions = KeySourceActions(
+        id="keyboard",
+        actions=[PauseAction(duration=0) for _ in click] + _type_actions("test"),
+    )
+
+    driver.input.perform_actions(driver.current_window_handle, [pointer_actions, key_actions])
+
+    WebDriverWait(driver, 5).until(lambda d: input_element.get_attribute("value") == "test")
+    assert input_element.get_attribute("value") == "test"
+
+
+def test_set_files(driver, pages):
+    """Test setting files on file input element."""
+    pages.load("formPage.html")
+
+    upload_element = driver.find_element(By.ID, "upload")
+    assert upload_element.get_attribute("value") == ""
+
+    temp_file_path = _make_temp_text_file("test content")
+    try:
+        driver.input.set_files(driver.current_window_handle, _shared_reference(upload_element), [temp_file_path])
+
+        # ``value`` exposes only the file name, so compare against the basename
+        value = upload_element.get_attribute("value")
+        assert os.path.basename(temp_file_path) in value
+    finally:
+        _remove_files([temp_file_path])
+
+
+def test_set_multiple_files(driver):
+    """Test setting multiple files on a file input element with 'multiple' attribute using BiDi."""
+    driver.get("data:text/html,<input id=upload type=file multiple />")
+
+    upload_element = driver.find_element(By.ID, "upload")
+
+    temp_files = []
+    try:
+        # Created inside ``try`` so files made before a failure are still cleaned up
+        for i in range(2):
+            temp_files.append(_make_temp_text_file(f"test content {i}"))
+
+        driver.input.set_files(driver.current_window_handle, _shared_reference(upload_element), temp_files)
+
+        assert upload_element.get_attribute("value") != ""
+        # ``value`` reflects only the first file; check that every file was attached
+        assert driver.execute_script("return arguments[0].files.length;", upload_element) == len(temp_files)
+    finally:
+        _remove_files(temp_files)
+
+
+def test_release_actions(driver, pages):
+    """Test releasing input actions."""
+    pages.load("single_text_input.html")
+
+    input_element = driver.find_element(By.ID, "textInput")
+
+    # Press "a" but deliberately leave it held down
+    key_actions = KeySourceActions(id="keyboard", actions=[KeyDownAction(value="a")])
+
+    driver.input.perform_actions(driver.current_window_handle, [key_actions])
+
+    # Release all input state, which should lift the held key
+    driver.input.release_actions(driver.current_window_handle)
+
+    # With the key released, typing should work normally
+    key_actions2 = KeySourceActions(id="keyboard", actions=_type_actions("b"))
+
+    driver.input.perform_actions(driver.current_window_handle, [key_actions2])
+
+    WebDriverWait(driver, 5).until(lambda d: "b" in input_element.get_attribute("value"))
+
+
+@pytest.mark.parametrize("multiple", [True, False])
+@pytest.mark.xfail_firefox(reason="File dialog handling not implemented in Firefox yet")
+def test_file_dialog_event_handler_multiple(driver, multiple):
+    """Test file dialog event handler with multiple as true and false."""
+    file_dialog_events = []
+
+    def file_dialog_handler(file_dialog_info):
+        file_dialog_events.append(file_dialog_info)
+
+    handler_id = driver.input.add_file_dialog_handler(file_dialog_handler)
+    assert handler_id is not None
+
+    try:
+        driver.get(f"data:text/html,<input id=upload type=file {'multiple' if multiple else ''} />")
+
+        _trigger_file_dialog(driver)
+
+        WebDriverWait(driver, 5).until(lambda d: len(file_dialog_events) > 0)
+
+        file_dialog_info = file_dialog_events[0]
+        assert isinstance(file_dialog_info, FileDialogInfo)
+        assert file_dialog_info.context == driver.current_window_handle
+        # ``multiple`` must mirror whether the input carries the multiple attribute
+        assert file_dialog_info.multiple is multiple
+    finally:
+        # Unregister even when an assertion fails so the handler cannot leak into later tests
+        driver.input.remove_file_dialog_handler(handler_id)
+
+
+@pytest.mark.xfail_firefox(reason="File dialog handling not implemented in Firefox yet")
+def test_file_dialog_event_handler_unsubscribe(driver):
+    """Test file dialog event handler unsubscribe."""
+    file_dialog_events = []
+
+    def file_dialog_handler(file_dialog_info):
+        file_dialog_events.append(file_dialog_info)
+
+    # Register the handler, then unsubscribe it before any dialog is opened
+    handler_id = driver.input.add_file_dialog_handler(file_dialog_handler)
+    assert handler_id is not None
+    driver.input.remove_file_dialog_handler(handler_id)
+
+    driver.get("data:text/html,<input id=upload type=file />")
+
+    _trigger_file_dialog(driver)
+
+    # No condition can be polled for an event that must never arrive, so wait a fixed window
+    time.sleep(1)
+    assert len(file_dialog_events) == 0
diff --git a/rb/spec/integration/selenium/webdriver/network_spec.rb b/rb/spec/integration/selenium/webdriver/network_spec.rb index b640644..ae1b2bf 100644 --- a/rb/spec/integration/selenium/webdriver/network_spec.rb +++ b/rb/spec/integration/selenium/webdriver/network_spec.rb
@@ -21,7 +21,7 @@ module Selenium module WebDriver - describe Network, exclude: {version: GlobalTestEnv.beta_chrome_version}, + describe Network, exclude: {version: 'beta'}, exclusive: {bidi: true, reason: 'only executed when bidi is enabled'}, only: {browser: %i[chrome edge firefox]} do let(:username) { SpecSupport::RackServer::TestApp::BASIC_AUTH_CREDENTIALS.first }
diff --git a/rb/spec/integration/selenium/webdriver/spec_support/test_environment.rb b/rb/spec/integration/selenium/webdriver/spec_support/test_environment.rb index 72d4d29..fa02021 100644 --- a/rb/spec/integration/selenium/webdriver/spec_support/test_environment.rb +++ b/rb/spec/integration/selenium/webdriver/spec_support/test_environment.rb
@@ -58,7 +58,7 @@ end def browser_version - driver_instance.capabilities.browser_version + ENV.fetch('WD_BROWSER_VERSION', 'stable') end def driver_instance(...) @@ -193,18 +193,6 @@ raise e end - def beta_chrome_version - chrome_beta_url = 'https://chromereleases.googleblog.com/search/label/Beta%20updates' - - uri = URI.parse(chrome_beta_url) - - response = Net::HTTP.get_response(uri) - - return "Failed to fetch Chrome Beta page: #{response&.code}" unless response.is_a?(Net::HTTPSuccess) - - response.body.match(/\d+\.\d+\.\d+\.\d+/).to_s - end - private def build_options(**)
diff --git a/rb/spec/tests.bzl b/rb/spec/tests.bzl index 85b7cb9..accd038 100644 --- a/rb/spec/tests.bzl +++ b/rb/spec/tests.bzl
@@ -111,6 +111,7 @@ "env": { "WD_REMOTE_BROWSER": "firefox", "WD_SPEC_DRIVER": "firefox", + "WD_BROWSER_VERSION": "beta", } | select({ "@selenium//common:use_pinned_linux_firefox": { "FIREFOX_BINARY": "$(location @linux_beta_firefox//:firefox/firefox)",