[grid] Fix latent bugs in WebSocket proxy (#17429)
* [grid] Forward orphan WebSocket frames inbound, not outbound
MessageInboundConverter handled orphan continuation frames and unknown
frame types by calling ctx.write(frame), which sends the frame back out
to the peer. They should travel forward through the inbound pipeline so
the upgrade/keepalive handlers can deal with them.
* [grid] Propagate ChannelPromise through MessageOutboundConverter
The converter was calling ctx.writeAndFlush(frame) and dropping the
incoming ChannelPromise, so any caller awaiting the original future
would never see it complete. Pipe the promise through ctx.write(...,
promise) and rely on the upstream writeAndFlush() to trigger the flush.
* [grid] Size BinaryMessage by readable bytes, not buffer capacity
BinaryMessage(ByteBuffer) sized its array by capacity(), which works
only because Netty's ByteBuf.nioBuffer() happens to return a slice
where capacity == remaining. A caller passing a flipped ByteBuffer
backed by a larger array would either pad zero bytes onto the message
or hit a BufferUnderflowException. Use remaining() so the contract
matches the comment "data to use".
* [grid] Latch upstreamClosing when WebSocketFrameProxy send fails
A failing forward currently fires the exception through the pipeline
but leaves upstreamClosing unset, so any frames already queued behind
this one re-attempt the same failing send before the close handshake
runs. Set the flag on first failure so subsequent frames short-circuit
to the drop path immediately.
* [grid] Release WebSocket consumer when handshake fails
The factory passed to WebSocketUpgradeHandler may have already opened
an upstream WebSocket and acquired a connection slot by the time the
WS handshake itself fails (unsupported Sec-WebSocket-Version, or the
handshake future completing exceptionally). The previous code dropped
the produced consumer without invoking it, so the upstream and the
slot leaked. Drive each failure path through the consumer's CloseMessage
cleanup so existing consumer logic frees the resources.
---------
Co-authored-by: Diego Molina <[email protected]>
diff --git a/java/src/org/openqa/selenium/netty/server/MessageInboundConverter.java b/java/src/org/openqa/selenium/netty/server/MessageInboundConverter.java
index 3773629..61a8439 100644
--- a/java/src/org/openqa/selenium/netty/server/MessageInboundConverter.java
+++ b/java/src/org/openqa/selenium/netty/server/MessageInboundConverter.java
@@ -90,7 +90,7 @@ protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
}
break;
case None:
- ctx.write(frame);
+ ctx.fireChannelRead(frame.retain());
return;
default:
throw new IllegalStateException("unexpected enum: " + next);
@@ -128,7 +128,7 @@ protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame;
message = new CloseMessage(closeFrame.statusCode(), closeFrame.reasonText());
} else {
- ctx.write(frame);
+ ctx.fireChannelRead(frame.retain());
return;
}
diff --git a/java/src/org/openqa/selenium/netty/server/MessageOutboundConverter.java b/java/src/org/openqa/selenium/netty/server/MessageOutboundConverter.java
index 0dcfab0..99e7462 100644
--- a/java/src/org/openqa/selenium/netty/server/MessageOutboundConverter.java
+++ b/java/src/org/openqa/selenium/netty/server/MessageOutboundConverter.java
@@ -46,14 +46,15 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
if (seMessage instanceof CloseMessage) {
CloseMessage closeMessage = (CloseMessage) seMessage;
- ctx.writeAndFlush(
- new CloseWebSocketFrame(true, 0, closeMessage.code(), closeMessage.reason()));
+ ctx.write(
+ new CloseWebSocketFrame(true, 0, closeMessage.code(), closeMessage.reason()), promise);
} else if (seMessage instanceof BinaryMessage) {
- ctx.writeAndFlush(
+ ctx.write(
new BinaryWebSocketFrame(
- true, 0, Unpooled.copiedBuffer(((BinaryMessage) seMessage).data())));
+ true, 0, Unpooled.copiedBuffer(((BinaryMessage) seMessage).data())),
+ promise);
} else if (seMessage instanceof TextMessage) {
- ctx.writeAndFlush(new TextWebSocketFrame(true, 0, ((TextMessage) seMessage).text()));
+ ctx.write(new TextWebSocketFrame(true, 0, ((TextMessage) seMessage).text()), promise);
} else {
LOG.warning(String.format("Unable to handle %s", msg));
super.write(ctx, msg, promise);
diff --git a/java/src/org/openqa/selenium/netty/server/WebSocketFrameProxy.java b/java/src/org/openqa/selenium/netty/server/WebSocketFrameProxy.java
index aad98b1..93a8a35 100644
--- a/java/src/org/openqa/selenium/netty/server/WebSocketFrameProxy.java
+++ b/java/src/org/openqa/selenium/netty/server/WebSocketFrameProxy.java
@@ -89,6 +89,9 @@ protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
try {
forwardFrame(frame);
} catch (Exception e) {
+ // Mark the upstream as closing so the next frame on this connection short-circuits
+ // rather than retrying the same failing send while the close handshake runs.
+ upstreamClosing.set(true);
LOG.log(Level.WARNING, "Failed to forward WebSocket frame to node", e);
ctx.fireExceptionCaught(e);
}
diff --git a/java/src/org/openqa/selenium/netty/server/WebSocketUpgradeHandler.java b/java/src/org/openqa/selenium/netty/server/WebSocketUpgradeHandler.java
index 109bd03..ed37172 100644
--- a/java/src/org/openqa/selenium/netty/server/WebSocketUpgradeHandler.java
+++ b/java/src/org/openqa/selenium/netty/server/WebSocketUpgradeHandler.java
@@ -98,6 +98,15 @@ private static String getWebSocketLocation(HttpRequest req) {
return "ws://" + req.headers().get(HttpHeaderNames.HOST);
}
+ private static void releaseHandlerOnHandshakeFailure(
+ Consumer<Message> handler, int code, String reason) {
+ try {
+ handler.accept(new CloseMessage(code, reason));
+ } catch (Exception ex) {
+ LOG.log(Level.FINE, "failed to release handler on handshake failure", ex);
+ }
+ }
+
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof HttpRequest) {
@@ -152,6 +161,10 @@ private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req) {
getWebSocketLocation(req), null, true, Integer.MAX_VALUE);
handshaker = wsFactory.newHandshaker(req);
if (handshaker == null) {
+ // The factory has already opened the upstream and (on the Node) acquired a connection
+ // slot. Drive the consumer through its CloseMessage cleanup path before we send the
+ // unsupported-version response, otherwise the upstream and the slot leak.
+ releaseHandlerOnHandshakeFailure(maybeHandler.get(), 1002, "unsupported websocket version");
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
} else {
ChannelFuture future = handshaker.handshake(ctx.channel(), req);
@@ -159,6 +172,11 @@ private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req) {
(ChannelFutureListener)
channelFuture -> {
if (!future.isSuccess()) {
+ // Same leak path: the consumer was never registered in the channel attr,
+ // so the generic exceptionCaught handler will not see it. Drive cleanup
+ // here so the upstream and any acquired slot are released.
+ releaseHandlerOnHandshakeFailure(
+ maybeHandler.get(), 1011, "websocket handshake failed");
ctx.fireExceptionCaught(future.cause());
} else {
Consumer<Message> handler = maybeHandler.get();
diff --git a/java/src/org/openqa/selenium/remote/http/BinaryMessage.java b/java/src/org/openqa/selenium/remote/http/BinaryMessage.java
index 7a62748..e0ba78b 100644
--- a/java/src/org/openqa/selenium/remote/http/BinaryMessage.java
+++ b/java/src/org/openqa/selenium/remote/http/BinaryMessage.java
@@ -26,7 +26,7 @@ public class BinaryMessage implements Message {
public BinaryMessage(ByteBuffer data) {
ByteBuffer copy = Require.nonNull("Data to use", data).asReadOnlyBuffer();
- this.data = new byte[copy.capacity()];
+ this.data = new byte[copy.remaining()];
copy.get(this.data);
}
diff --git a/java/test/org/openqa/selenium/netty/server/BUILD.bazel b/java/test/org/openqa/selenium/netty/server/BUILD.bazel
index d6bfcef..ab63f94 100644
--- a/java/test/org/openqa/selenium/netty/server/BUILD.bazel
+++ b/java/test/org/openqa/selenium/netty/server/BUILD.bazel
@@ -2,7 +2,11 @@
load("//java:defs.bzl", "JUNIT5_DEPS", "java_library", "java_test_suite")
SMALL_TEST_SRCS = [
+ "MessageInboundConverterTest.java",
+ "MessageOutboundConverterTest.java",
"RequestConverterTest.java",
+ "WebSocketFrameProxyTest.java",
+ "WebSocketUpgradeHandlerTest.java",
]
java_test_suite(
@@ -11,7 +15,10 @@
srcs = SMALL_TEST_SRCS,
deps = [
"//java/src/org/openqa/selenium/netty/server",
+ "//java/src/org/openqa/selenium/remote/http",
+ artifact("io.netty:netty-buffer"),
artifact("io.netty:netty-codec-http"),
+ artifact("io.netty:netty-common"),
artifact("io.netty:netty-transport"),
artifact("org.junit.jupiter:junit-jupiter-api"),
artifact("org.assertj:assertj-core"),
diff --git a/java/test/org/openqa/selenium/netty/server/MessageInboundConverterTest.java b/java/test/org/openqa/selenium/netty/server/MessageInboundConverterTest.java
new file mode 100644
index 0000000..569c286
--- /dev/null
+++ b/java/test/org/openqa/selenium/netty/server/MessageInboundConverterTest.java
@@ -0,0 +1,64 @@
+// Licensed to the Software Freedom Conservancy (SFC) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The SFC licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.openqa.selenium.netty.server;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import io.netty.buffer.Unpooled;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
+import org.junit.jupiter.api.Test;
+
+class MessageInboundConverterTest {
+
+ @Test
+ void orphanContinuationFrameIsForwardedInbound() {
+ EmbeddedChannel channel = new EmbeddedChannel(new MessageInboundConverter());
+
+ // No prior Text/Binary frame has set up a continuation context, so this is an orphan.
+ ContinuationWebSocketFrame orphan =
+ new ContinuationWebSocketFrame(true, 0, Unpooled.wrappedBuffer("x".getBytes()));
+
+ assertThat(channel.writeInbound(orphan)).isTrue();
+
+ // The frame should travel forward through the pipeline, NOT back out as a write.
+ Object inbound = channel.readInbound();
+ assertThat(inbound).isInstanceOf(ContinuationWebSocketFrame.class);
+ assertThat(channel.outboundMessages()).isEmpty();
+
+ ((ContinuationWebSocketFrame) inbound).release();
+ }
+
+ @Test
+ void unknownFrameTypeIsForwardedInbound() {
+ EmbeddedChannel channel = new EmbeddedChannel(new MessageInboundConverter());
+
+ // Ping is not a frame type the converter handles; it should be passed forward
+ // for the keepalive/upgrade handler to deal with, not echoed back to the peer.
+ PingWebSocketFrame ping = new PingWebSocketFrame();
+
+ assertThat(channel.writeInbound(ping)).isTrue();
+
+ Object inbound = channel.readInbound();
+ assertThat(inbound).isInstanceOf(PingWebSocketFrame.class);
+ assertThat(channel.outboundMessages()).isEmpty();
+
+ ((PingWebSocketFrame) inbound).release();
+ }
+}
diff --git a/java/test/org/openqa/selenium/netty/server/MessageOutboundConverterTest.java b/java/test/org/openqa/selenium/netty/server/MessageOutboundConverterTest.java
new file mode 100644
index 0000000..24df1b7
--- /dev/null
+++ b/java/test/org/openqa/selenium/netty/server/MessageOutboundConverterTest.java
@@ -0,0 +1,47 @@
+// Licensed to the Software Freedom Conservancy (SFC) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The SFC licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.openqa.selenium.netty.server;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+import org.junit.jupiter.api.Test;
+import org.openqa.selenium.remote.http.TextMessage;
+
+class MessageOutboundConverterTest {
+
+ @Test
+ void promiseIsCompletedWhenMessageIsWritten() {
+ EmbeddedChannel channel = new EmbeddedChannel(new MessageOutboundConverter());
+
+ ChannelFuture future = channel.writeAndFlush(new TextMessage("hello"));
+
+ // The frame should land on the wire and the upstream future must complete:
+ // EmbeddedChannel runs synchronously, so success means our promise was honoured
+ // rather than dropped in favour of a fresh writeAndFlush() promise.
+ assertThat(future.isSuccess()).isTrue();
+
+ Object outbound = channel.readOutbound();
+ assertThat(outbound).isInstanceOf(TextWebSocketFrame.class);
+ TextWebSocketFrame frame = (TextWebSocketFrame) outbound;
+ assertThat(frame.text()).isEqualTo("hello");
+ frame.release();
+ }
+}
diff --git a/java/test/org/openqa/selenium/netty/server/WebSocketFrameProxyTest.java b/java/test/org/openqa/selenium/netty/server/WebSocketFrameProxyTest.java
new file mode 100644
index 0000000..ce7e404
--- /dev/null
+++ b/java/test/org/openqa/selenium/netty/server/WebSocketFrameProxyTest.java
@@ -0,0 +1,57 @@
+// Licensed to the Software Freedom Conservancy (SFC) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The SFC licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.openqa.selenium.netty.server;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.junit.jupiter.api.Test;
+import org.openqa.selenium.remote.http.Message;
+import org.openqa.selenium.remote.http.WebSocket;
+
+class WebSocketFrameProxyTest {
+
+ @Test
+ void marksUpstreamClosingWhenForwardingFails() {
+ WebSocket throwingUpstream =
+ new WebSocket() {
+ @Override
+ public WebSocket send(Message message) {
+ throw new RuntimeException("upstream gone");
+ }
+
+ @Override
+ public void close() {}
+ };
+
+ AtomicBoolean upstreamClosing = new AtomicBoolean(false);
+ EmbeddedChannel channel =
+ new EmbeddedChannel(new WebSocketFrameProxy(throwingUpstream, upstreamClosing));
+
+ try {
+ channel.writeInbound(new TextWebSocketFrame("hi"));
+ } catch (RuntimeException expected) {
+ // The proxy fires the exception through the pipeline; EmbeddedChannel rethrows it.
+ }
+
+ // First failure should latch upstreamClosing so subsequent frames short-circuit.
+ assertThat(upstreamClosing.get()).isTrue();
+ }
+}
diff --git a/java/test/org/openqa/selenium/netty/server/WebSocketUpgradeHandlerTest.java b/java/test/org/openqa/selenium/netty/server/WebSocketUpgradeHandlerTest.java
new file mode 100644
index 0000000..46dd04e
--- /dev/null
+++ b/java/test/org/openqa/selenium/netty/server/WebSocketUpgradeHandlerTest.java
@@ -0,0 +1,63 @@
+// Licensed to the Software Freedom Conservancy (SFC) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The SFC licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.openqa.selenium.netty.server;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.DefaultFullHttpRequest;
+import io.netty.handler.codec.http.FullHttpRequest;
+import io.netty.handler.codec.http.HttpHeaderNames;
+import io.netty.handler.codec.http.HttpMethod;
+import io.netty.handler.codec.http.HttpServerCodec;
+import io.netty.handler.codec.http.HttpVersion;
+import io.netty.util.AttributeKey;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicReference;
+import org.junit.jupiter.api.Test;
+import org.openqa.selenium.remote.http.CloseMessage;
+import org.openqa.selenium.remote.http.Message;
+
+class WebSocketUpgradeHandlerTest {
+
+ @Test
+ void unsupportedWebSocketVersionDrivesConsumerCleanup() {
+ AtomicReference<Message> receivedByConsumer = new AtomicReference<>();
+
+ AttributeKey<java.util.function.Consumer<Message>> key =
+ AttributeKey.valueOf("ws-upgrade-handler-test");
+ WebSocketUpgradeHandler handler =
+ new WebSocketUpgradeHandler(key, (uri, downstream) -> Optional.of(receivedByConsumer::set));
+
+ EmbeddedChannel channel = new EmbeddedChannel(new HttpServerCodec(), handler);
+
+ FullHttpRequest req =
+ new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/session/abc/se/bidi");
+ req.headers().set(HttpHeaderNames.HOST, "localhost");
+ req.headers().set(HttpHeaderNames.CONNECTION, "Upgrade");
+ req.headers().set(HttpHeaderNames.UPGRADE, "websocket");
+ // Version 99 is not in Netty's supported set, so newHandshaker() returns null.
+ req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "99");
+ req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==");
+
+ channel.writeInbound(req);
+
+ // Without the cleanup, the consumer is dropped and any acquired resources leak.
+ assertThat(receivedByConsumer.get()).isInstanceOf(CloseMessage.class);
+ }
+}
diff --git a/java/test/org/openqa/selenium/remote/http/BinaryMessageTest.java b/java/test/org/openqa/selenium/remote/http/BinaryMessageTest.java
new file mode 100644
index 0000000..0f5a833
--- /dev/null
+++ b/java/test/org/openqa/selenium/remote/http/BinaryMessageTest.java
@@ -0,0 +1,38 @@
+// Licensed to the Software Freedom Conservancy (SFC) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The SFC licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.openqa.selenium.remote.http;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.nio.ByteBuffer;
+import org.junit.jupiter.api.Test;
+
+class BinaryMessageTest {
+
+ @Test
+ void copiesOnlyTheReadableRegionOfABuffer() {
+ // Backing array is 16 bytes but only the slice [4..8) is readable.
+ byte[] backing = {0, 0, 0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0};
+ ByteBuffer buffer = ByteBuffer.wrap(backing);
+ buffer.position(4).limit(8);
+
+ BinaryMessage message = new BinaryMessage(buffer);
+
+ assertThat(message.data()).containsExactly(1, 2, 3, 4);
+ }
+}