blob: 53af0294a46380c1d954a8309555f99a7c50f1a2 [file]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef SERVICES_AUDIO_ML_MODEL_MANAGER_H_
#define SERVICES_AUDIO_ML_MODEL_MANAGER_H_
#include <memory>
#include <optional>

#include "base/files/file.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/thread_annotations.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "services/audio/public/mojom/ml_model_manager.mojom.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/tflite/src/tensorflow/lite/model_builder.h"
namespace audio {
struct ModelWithBuffer;
// Abstract handle through which a client accesses a TFLite model. Handles are
// handed out as std::unique_ptr by
// MlModelManager::GetResidualEchoEstimationModel().
class MlModelHandle {
public:
virtual ~MlModelHandle() = default;
// Returns a pointer to a TFLite model file, valid for the lifetime of the
// MlModelHandle.
virtual const tflite::FlatBufferModel* Get() = 0;
};
// Interface for providing Machine Learning models within the audio service.
// This interface is used by components like the AudioProcessorHandler to access
// ML model information.
class MlModelManager {
public:
virtual ~MlModelManager() = default;
// Returns a handle for a TFLite model file for the neural residual echo
// estimator. Returns nullptr if no valid model file is available.
//
// Lifetime: the model shares its lifetime with the MlModelManager, so the
// client must stop using and destroy the returned handle before the
// MlModelManager is destroyed.
virtual std::unique_ptr<MlModelHandle> GetResidualEchoEstimationModel() = 0;
};
// Implementation of the MlModelManager interface. This class receives model
// information from the browser process through the mojom::MlModelManager Mojo
// interface and provides it to audio service components.
//
// Current Behavior:
// - Model files provided via SetResidualEchoEstimationModel() are loaded and
// cached for as long as they are used or may be used.
// - GetResidualEchoEstimationModel() returns the last successfully loaded
// model.
// - StopServingResidualEchoEstimationModel() stops serving all previously set
// models, but the models are kept alive for as long as existing clients need
// them.
class MlModelManagerImpl : public MlModelManager, public mojom::MlModelManager {
public:
MlModelManagerImpl();
~MlModelManagerImpl() override;
// Not copyable.
MlModelManagerImpl(const MlModelManagerImpl&) = delete;
MlModelManagerImpl& operator=(const MlModelManagerImpl&) = delete;
// Takes the pending end of the mojom::MlModelManager pipe. Presumably binds
// it to `receiver_` so that this object starts receiving the Mojo calls
// below -- confirm against the implementation.
void BindReceiver(mojo::PendingReceiver<mojom::MlModelManager> receiver);
// mojom::MlModelManager implementation.
void SetResidualEchoEstimationModel(base::File tflite_file) override;
void StopServingResidualEchoEstimationModel() override;
// MlModelManager implementation.
std::unique_ptr<MlModelHandle> GetResidualEchoEstimationModel() override;
// Stops any ongoing loading of model files.
void CancelModelLoadingTasks();
// Test-only accessor. Presumably reports whether model loading work is still
// outstanding -- confirm against the implementation.
bool HasPendingTasksForTesting() const;
private:
// Callback for MlModelHandle to track active clients for a model.
// `model_with_buffer` presumably identifies the model whose handle was
// destroyed -- confirm against the implementation.
void OnModelHandleDestruction(ModelWithBuffer* model_with_buffer);
// Continuation for SetResidualEchoEstimationModel after work on background
// thread.
void OnResidualEchoEstimationModelRead(
std::unique_ptr<ModelWithBuffer> model_with_buffer);
SEQUENCE_CHECKER(sequence_checker_);
// Receiver end of the mojom::MlModelManager interface. Wrapped in
// std::optional so it can be absent; presumably it is populated by
// BindReceiver() -- confirm against the implementation.
std::optional<mojo::Receiver<mojom::MlModelManager>> receiver_
GUARDED_BY_CONTEXT(sequence_checker_);
// The service can hold on to multiple models simultaneously. The models
// travel between three storages:
//
// 1. `unused_serving_model_`: A newly loaded model is placed here. It's
// available for use but hasn't been requested yet.
//
// 2. `used_serving_model_`: When a client requests a model via
// `GetResidualEchoEstimationModel()`, the model is moved from
// `unused_serving_model_` to here. This indicates the model is actively
// in use. If all clients stop using it, it moves back to
// `unused_serving_model_`.
//
// 3. `retired_models_`: If a new model is loaded while the current one in
// `used_serving_model_` is still being used, the `used_serving_model_`
// is moved into this map. This keeps the model alive for existing clients
// while allowing new clients to use the new model.
//
// At most one of `unused_serving_model_` and `used_serving_model_` is
// non-null at any given time.
std::unique_ptr<ModelWithBuffer> unused_serving_model_
GUARDED_BY_CONTEXT(sequence_checker_);
std::unique_ptr<ModelWithBuffer> used_serving_model_
GUARDED_BY_CONTEXT(sequence_checker_);
// Maps ModelWithBuffer pointers to the respective full ModelWithBuffer
// object, for easy lookup.
absl::flat_hash_map<ModelWithBuffer*, std::unique_ptr<ModelWithBuffer>>
retired_models_ GUARDED_BY_CONTEXT(sequence_checker_);
// Declared last so that it is destroyed first (members are destroyed in
// reverse declaration order), following the usual base::WeakPtrFactory
// convention. Keep it as the final member.
base::WeakPtrFactory<MlModelManagerImpl> weak_factory_{this};
};
} // namespace audio
#endif // SERVICES_AUDIO_ML_MODEL_MANAGER_H_