
Defining custom embedders

The BaseEmbedder is the abstract base class for all embedders. It is defined on generic input and output data types and enforces a common interface for all embedders by defining the info and config properties, as well as the embed method. It also declares the abstract validate_input and validate_output methods, which subclasses must override.
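To make this interface concrete, here is a minimal sketch of how such a generic abstract base class can be structured in plain Python. It is not QoolQit's source. SketchBaseEmbedder, IntToIntEmbedder, ScaleConfig and scale are hypothetical names. The sketch also assumes that embed calls validate_input, runs the algorithm with the config fields as keyword arguments, and then calls validate_output.

```python
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Callable, Generic, TypeVar

InT = TypeVar("InT")
OutT = TypeVar("OutT")


class SketchBaseEmbedder(ABC, Generic[InT, OutT]):
    """Hypothetical stand-in for BaseEmbedder; not QoolQit's actual source."""

    def __init__(self, algorithm: Callable[..., OutT], config: Any) -> None:
        self._algorithm = algorithm
        self._config = config  # assumed to be a dataclass instance

    @property
    def config(self) -> Any:
        return self._config

    @property
    def info(self) -> str:
        # Report the algorithm's docstring as the embedder's description.
        return self._algorithm.__doc__ or ""

    @abstractmethod
    def validate_input(self, data: InT) -> None:
        """Raise if `data` is not a valid input for this family."""

    @abstractmethod
    def validate_output(self, result: OutT) -> None:
        """Raise if `result` is not a valid output for this family."""

    def embed(self, data: InT) -> OutT:
        self.validate_input(data)
        # Assumption: config fields are passed to the algorithm as keyword arguments.
        result = self._algorithm(data, **asdict(self._config))
        self.validate_output(result)
        return result


# Toy "family" that fixes both data types to int (illustration only).
class IntToIntEmbedder(SketchBaseEmbedder[int, int]):
    def validate_input(self, data: int) -> None:
        if not isinstance(data, int):
            raise TypeError(f"Expected int input, got {type(data).__name__}")

    def validate_output(self, result: int) -> None:
        if not isinstance(result, int):
            raise TypeError(f"Expected int output, got {type(result).__name__}")


@dataclass
class ScaleConfig:
    factor: int = 2


def scale(x: int, factor: int) -> int:
    """Multiply the input by factor."""
    return x * factor


embedder = IntToIntEmbedder(scale, ScaleConfig())
print(embedder.embed(3))  # 6
```

Because validate_input and validate_output are abstract, the base class itself cannot be instantiated. Only a subclass that fixes the data types and implements both checks can be instantiated.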

The next level is to define the concrete data types involved in the mapping, thus defining a family of embedders. Currently, there are two families of embedders defined in QoolQit:

  • GraphToGraphEmbedder which concretizes the BaseEmbedder with a DataGraph input type and a DataGraph output type.
  • MatrixToGraphEmbedder which concretizes the BaseEmbedder with a np.ndarray input type and a DataGraph output type.

In both cases, the validate_input and validate_output methods are overridden to check that the input and output are of the correct type. MatrixToGraphEmbedder additionally checks conditions on the input matrix, such as whether the array has the right dimensions and is symmetric. At this level, however, no specific embedding algorithm is defined yet.
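The matrix checks described above can be sketched with NumPy as follows. validate_matrix_input is a hypothetical helper, not MatrixToGraphEmbedder's actual method. The choice of ValueError for shape and symmetry failures and the use of np.allclose as the symmetry tolerance are also assumptions.

```python
import numpy as np


def validate_matrix_input(matrix: object) -> None:
    """Hypothetical check in the spirit of MatrixToGraphEmbedder.validate_input."""
    # Type check: the family's input type is np.ndarray.
    if not isinstance(matrix, np.ndarray):
        raise TypeError(f"Expected a np.ndarray, got {type(matrix).__name__}")
    # Dimension check: the array must be two-dimensional and square.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"Expected a square 2D array, got shape {matrix.shape}")
    # Symmetry check, with a floating-point tolerance.
    if not np.allclose(matrix, matrix.T):
        raise ValueError("Expected a symmetric matrix")


validate_matrix_input(np.array([[0.0, 1.0], [1.0, 0.0]]))  # passes silently
```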

In the future, more families of embedders can be defined that may require different input and output data types.

Level 2: Concretizing the algorithm and config

The final level is defining concrete embedders, such as the ones used on the available embedders page. This requires a concrete function that maps the input to the output, taking any required parameters as extra arguments, and a config dataclass inheriting from EmbedderConfig that holds all of these configuration parameters.

from dataclasses import dataclass

from qoolqit import DataGraph
from qoolqit.embedding import EmbedderConfig, GraphToGraphEmbedder


def my_embedding_function(graph: DataGraph, param1: float) -> DataGraph:
    """Some embedding function that manipulates the input graph.

    This docstring should be clear on the embedding logic, because it will be
    directly accessed by the embedder.info property.

    Arguments:
        param1: a useless parameter...
    """
    return graph


@dataclass
class MyEmbedderConfig(EmbedderConfig):
    param1: float = 1.0


embedder = GraphToGraphEmbedder(my_embedding_function, MyEmbedderConfig())
print(embedder)
print(embedder.info)

# Configuration parameters can be changed on an existing embedder.
embedder.config.param1 = 2.0

graph = DataGraph.random_er(5, 0.5)
embedded_graph = embedder.embed(graph)


# The same embedder can also be packaged as its own class.
class MyNewEmbedder(GraphToGraphEmbedder):
    def __init__(self):
        super().__init__(my_embedding_function, MyEmbedderConfig())

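When defining a custom embedder, the extra arguments of the embedding function (besides the data) must match the fields of the configuration dataclass. Otherwise, an error is raised when the embedder is created. Likewise, embed raises an error if the input data or the output of the embedding function does not have the expected type.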
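One way to implement the matching rule above is with the standard library's inspect and dataclasses modules. This is a minimal sketch, not QoolQit's implementation. check_algorithm_matches_config and the example config classes are hypothetical. The sketch assumes that the first parameter receives the data and that the remaining parameter names must equal the config field names exactly.

```python
import inspect
from dataclasses import dataclass, fields
from typing import Callable


def check_algorithm_matches_config(algorithm: Callable, config: object) -> None:
    """Hypothetical check: the algorithm's extra parameters must equal the config fields."""
    params = list(inspect.signature(algorithm).parameters)
    extra_params = set(params[1:])  # the first parameter receives the data
    field_names = {f.name for f in fields(config)}
    if extra_params != field_names:
        raise TypeError(
            f"Algorithm parameters {sorted(extra_params)} do not match "
            f"config fields {sorted(field_names)}"
        )


@dataclass
class GoodConfig:
    param1: float = 1.0


@dataclass
class WrongConfig:
    some_other_param: float = 1.0


def embedding_function(graph: object, param1: float) -> object:
    return graph


check_algorithm_matches_config(embedding_function, GoodConfig())  # no error
try:
    check_algorithm_matches_config(embedding_function, WrongConfig())
except TypeError as error:
    print(error)
```

Running this check once at construction time surfaces a mismatch immediately, instead of failing later inside embed.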

# The extra argument `param1` has no matching field in MyWrongConfig,
# so creating the embedder raises a TypeError.
def my_embedding_function(graph: DataGraph, param1: float) -> DataGraph:
    return graph


@dataclass
class MyWrongConfig(EmbedderConfig):
    some_other_param: float = 1.0


try:
    wrong_embedder = GraphToGraphEmbedder(my_embedding_function, MyWrongConfig())
except TypeError as error:
    print(error)

# Input of the wrong type is rejected when calling embed.
# (Reuses MyEmbedderConfig from the previous example.)
embedder = GraphToGraphEmbedder(my_embedding_function, MyEmbedderConfig())

try:
    data = 1.0  # Not a DataGraph
    embedded_data = embedder.embed(data)
except TypeError as error:
    print(error)


# Output of the wrong type is also rejected when calling embed.
def my_wrong_embedding_function(graph: DataGraph, param1: float) -> DataGraph:
    return param1  # Not a DataGraph


embedder = GraphToGraphEmbedder(my_wrong_embedding_function, MyEmbedderConfig())

try:
    graph = DataGraph.random_er(5, 0.5)
    embedded_graph = embedder.embed(graph)
except TypeError as error:
    print(error)