blob: 88b890798e104e1051e417e268601219bdaf6e48 [file]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/network/websocket_factory.h"
#include <memory>
#include <string>
#include <vector>
#include "base/functional/callback.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/system/functions.h"
#include "net/base/isolation_info.h"
#include "net/log/net_log.h"
#include "net/log/net_log_entry.h"
#include "net/log/net_log_event_type.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "services/network/network_context.h"
#include "services/network/network_service.h"
#include "services/network/public/mojom/network_context.mojom.h"
#include "services/network/public/mojom/websocket.mojom.h"
#include "services/network/test/fake_test_cert_verifier_params_factory.h"
#include "services/network/test/test_network_context_client.h"
#include "services/network/test/test_utils.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"
#include "url/origin.h"
namespace network {
namespace {
// Simple NetLog observer that collects entries for testing.
// Unlike RecordingNetLogObserver, this does not auto-register with a NetLog
// instance, so it can be used with CreateNetLogEntriesForActiveConnections()
// which calls OnAddEntry() directly.
class TestNetLogObserver : public net::NetLog::ThreadSafeObserver {
public:
TestNetLogObserver() = default;
~TestNetLogObserver() override = default;
void OnAddEntry(const net::NetLogEntry& entry) override {
entries_.push_back(entry.Clone());
}
const std::vector<net::NetLogEntry>& entries() const { return entries_; }
private:
std::vector<net::NetLogEntry> entries_;
};
class StubWebSocketHandshakeClient : public mojom::WebSocketHandshakeClient {
public:
explicit StubWebSocketHandshakeClient(
mojo::PendingReceiver<mojom::WebSocketHandshakeClient> receiver)
: receiver_(this, std::move(receiver)) {}
~StubWebSocketHandshakeClient() override = default;
void OnOpeningHandshakeStarted(mojom::WebSocketHandshakeRequestPtr) override {
}
void OnConnectionEstablished(mojo::PendingRemote<mojom::WebSocket>,
mojo::PendingReceiver<mojom::WebSocketClient>,
mojom::WebSocketHandshakeResponsePtr,
mojo::ScopedDataPipeConsumerHandle,
mojo::ScopedDataPipeProducerHandle) override {}
void OnFailure(const std::string&, int32_t, int32_t) override {}
private:
mojo::Receiver<mojom::WebSocketHandshakeClient> receiver_;
};
mojom::NetworkContextParamsPtr CreateNetworkContextParams() {
auto params = CreateNetworkContextParamsForTesting();
// Use a dummy CertVerifier that always passes cert verification, since
// these unittests don't need to test CertVerifier behavior.
params->cert_verifier_params =
FakeTestCertVerifierParamsFactory::GetCertVerifierParams();
return params;
}
class WebSocketFactoryTest : public testing::Test {
public:
WebSocketFactoryTest()
: task_environment_(base::test::TaskEnvironment::MainThreadType::IO),
network_service_(NetworkService::CreateForTesting()) {
network_context_ = std::make_unique<NetworkContext>(
network_service_.get(),
network_context_remote_.BindNewPipeAndPassReceiver(),
CreateNetworkContextParams(), base::DoNothing());
factory_ = std::make_unique<WebSocketFactory>(network_context_.get());
}
void CreateWebSocket(const GURL& url,
const std::vector<std::string>& requested_protocols) {
mojo::PendingRemote<mojom::WebSocketHandshakeClient> handshake_client;
// Keep all handshake clients alive to prevent WebSocket cleanup
stub_handshake_clients_.push_back(
std::make_unique<StubWebSocketHandshakeClient>(
handshake_client.InitWithNewPipeAndPassReceiver()));
// WebSocket objects are owned by the factory and will be deleted
// asynchronously.
factory_->CreateWebSocket(
url, requested_protocols, net::StorageAccessApiStatus::kNone,
net::IsolationInfo(), {}, network::OriginatingProcessId::browser(),
url::Origin::Create(url),
/*client_security_state=*/nullptr, /*options=*/0,
TRAFFIC_ANNOTATION_FOR_TESTS, std::move(handshake_client),
mojo::NullRemote(), mojo::NullRemote(), mojo::NullRemote(),
/*throttling_profile_id=*/std::nullopt,
/*network_restrictions_id=*/std::nullopt);
}
protected:
base::test::TaskEnvironment task_environment_;
std::unique_ptr<NetworkService> network_service_;
mojo::Remote<mojom::NetworkContext> network_context_remote_;
std::unique_ptr<NetworkContext> network_context_;
std::unique_ptr<WebSocketFactory> factory_;
std::vector<std::unique_ptr<StubWebSocketHandshakeClient>>
stub_handshake_clients_;
};
class WebSocketFactoryBadMessageTest : public WebSocketFactoryTest {
public:
WebSocketFactoryBadMessageTest()
: dummy_message_(0, 0, 0, 0, nullptr), context_(&dummy_message_) {
mojo::SetDefaultProcessErrorHandler(
future_.GetRepeatingCallback<const std::string&>());
}
~WebSocketFactoryBadMessageTest() override {
mojo::SetDefaultProcessErrorHandler(base::NullCallback());
}
std::string GetBadMessage() { return future_.Get(); }
private:
base::test::TestFuture<std::string> future_;
// Because we don't really call the factory from mojo, we need to create a
// fake message context.
mojo::Message dummy_message_;
mojo::internal::MessageDispatchContext context_;
};
TEST_F(WebSocketFactoryBadMessageTest, InvalidScheme) {
CreateWebSocket(GURL("http://example.com/"), {});
EXPECT_EQ(GetBadMessage(), "Invalid scheme.");
}
TEST_F(WebSocketFactoryBadMessageTest, InvalidSubprotocol) {
CreateWebSocket(GURL("ws://example.com/"), {"invalid protocol"});
EXPECT_EQ(GetBadMessage(), "Invalid protocols.");
}
TEST_F(WebSocketFactoryBadMessageTest, EmptySubprotocol) {
CreateWebSocket(GURL("ws://example.com/"), {"okay", "", "okay2"});
EXPECT_EQ(GetBadMessage(), "Invalid protocols.");
}
TEST_F(WebSocketFactoryBadMessageTest, DuplicateSubprotocol) {
CreateWebSocket(GURL("ws://example.com/"), {"foo", "foo"});
EXPECT_EQ(GetBadMessage(), "Invalid protocols.");
}
// Tests for CreateNetLogEntriesForActiveConnections()
TEST_F(WebSocketFactoryTest, CreateNetLogEntriesEmptyFactory) {
// An empty factory should produce no NetLog entries.
TestNetLogObserver observer;
factory_->CreateNetLogEntriesForActiveConnections(&observer);
EXPECT_TRUE(observer.entries().empty());
}
TEST_F(WebSocketFactoryTest, CreateNetLogEntriesForActiveConnections) {
// Create multiple connections to verify entries are produced for each.
CreateWebSocket(GURL("wss://example.com/chat1"), {});
CreateWebSocket(GURL("wss://example.com/chat2"), {});
TestNetLogObserver observer;
factory_->CreateNetLogEntriesForActiveConnections(&observer);
// Each connection with a channel should produce one WEBSOCKET_ALIVE entry.
EXPECT_EQ(observer.entries().size(), 2u);
for (const auto& entry : observer.entries()) {
EXPECT_EQ(entry.type, net::NetLogEventType::WEBSOCKET_ALIVE);
EXPECT_EQ(entry.phase, net::NetLogEventPhase::BEGIN);
EXPECT_TRUE(entry.HasParams());
}
}
} // namespace
} // namespace network