
Reference

This part of the project documentation is information-oriented. Use it as a reference for the technical implementation of the Aeiva codebase.

Aeiva API references

action

action

Action

Bases: Step

Represents an executable action, extending the Step class. An action wraps a tool with state and state-management methods, and it can be executed to perform the tool's functionality.

Source code in src/aeiva/action/action.py
class Action(Step):
    """
    Represents an executable action, extending the Step class.
    An action wraps a tool with state and state-management methods, and it can be executed to perform the tool's functionality.
    """

    def __init__(self, name: str, params: Dict[str, Any] = None,
                 id: str = None, dependent_ids: Optional[List[str]] = None, 
                 type: Optional[str] = None, description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        super().__init__(name=name, params=params,
                         id=id, dependent_ids=dependent_ids,
                         type=type, description=description,
                         metadata=metadata)
        self.type = "Action"
        self.tool = Tool(name)
        self.result = None

    def reset(self) -> None:
        """
        Resets the step status, making it ready for re-execution.
        """
        self.result = None
        self.status = Status.NOT_EXECUTED

    async def execute(self, params: Dict[str, Any]) -> Any:
        if self.tool is None:
            raise ValueError(f"Action {self.id} has no tool assigned for execution.")

        self.start()
        try:
            result = await self.tool.execute(params)  # Assuming the tool's execute method is async
            self.end(success=True)
            self.result = result
            return result
        except Exception as e:
            self.end(success=False)
            raise RuntimeError(f"Action {self.id} failed: {str(e)}")
reset()

Resets the step status, making it ready for re-execution.

Source code in src/aeiva/action/action.py
def reset(self) -> None:
    """
    Resets the step status, making it ready for re-execution.
    """
    self.result = None
    self.status = Status.NOT_EXECUTED
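
A minimal usage sketch (illustrative, not part of the generated reference). It assumes the import path aeiva.action.action from the source layout above, and a hypothetical tool name "calculator" that Tool(name) can resolve in your installation:

import asyncio

from aeiva.action.action import Action

async def main():
    # "calculator" and its params are hypothetical; substitute a tool
    # that is actually available in your setup.
    action = Action(name="calculator", params={"expression": "1 + 2"}, id="a1")
    result = await action.execute(action.params)
    print(action.status, result)  # e.g. "Success" plus the tool's output

asyncio.run(main())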

action_system

ActionSystem

A concrete Action System responsible for translating Plans into executable Skills and managing the execution of Skills.

Source code in src/aeiva/action/action_system.py
class ActionSystem:
    """
    A concrete Action System responsible for translating Plans into executable Skills
    and managing the execution of Skills.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.state = {
            "current_skill": None,
            "execution_status": "Not Started",
        }
        self.tools = []
        self.skill = None

    def setup(self) -> None:
        if "tools" in self.config.keys():
            for tool_name in self.config["tools"]:
                self.tools.append(Tool.load_tool_schema(tool_name))
        print("ActionSystem setup complete.")

    def plan_to_skill(self, plan: Plan) -> Skill:
        actions = []

        for task in plan.steps:
            if isinstance(task, Task):
                action = Action(
                    name=task.name,
                    params=task.params,
                    id=task.id,
                    dependent_ids=task.dependent_ids,
                    type="Action",
                    description=task.description,
                    metadata=task.metadata
                )
                actions.append(action)
            elif isinstance(task, Plan):
                sub_skill = self.plan_to_skill(task)  # Recursively handle sub-plans
                actions.append(sub_skill)
            else:
                raise TypeError(f"Unexpected step type: {type(task)} in plan {plan.id}")

        if not actions:
            raise ValueError(f"The plan {plan.id} does not contain any valid actions or sub-plans.")

        return Skill(
            name=plan.name,
            steps=actions,
            id=plan.id,
            dependent_ids=plan.dependent_ids,
            type="Skill",
            description=plan.description,
            metadata=plan.metadata
        )

    async def execute(self, plan: Plan) -> None:
        self.state["execution_status"] = "Executing"

        try:
            self.skill = self.plan_to_skill(plan)            
            await self.skill.execute()            
            self.state["execution_status"] = "Completed" if self.skill.is_successful else "Failed"
        except Exception as e:
            self.state["execution_status"] = "Failed"
            self.handle_error(e)
            raise  # Re-raise the exception

    def handle_error(self, error: Exception) -> None:
        print(f"ActionSystem encountered an error: {error}")

experience

Experience

Bases: Procedure

Represents an experience, which is a structured composition of actions. Unlike a skill, an experience cannot be executed until it is validated and transformed into a skill.

Attributes:

owner (str): The person or agent who owns the experience.

reliable (bool): A flag indicating whether the experience is reliable enough to be transformed into a skill.

Source code in src/aeiva/action/experience.py
class Experience(Procedure):
    """
    Represents an experience, which is a structured composition of actions.
    Unlike a skill, an experience cannot be executed until it is validated and transformed into a skill.

    Attributes:
        owner (str): The person or agent who owns the experience.
        reliable (bool): A flag indicating whether the experience is reliable enough to be transformed into a skill.
    """

    def __init__(self, name: str, steps: List[Union['Experience', Action]],
                 owner: Optional[str] = None, reliable: Optional[bool] = False,
                 id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,
                 type: Optional[str] = None, description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        """
        Initializes an Experience by extending Procedure.
        """
        super().__init__(name=name, steps=steps,
                         id=id, dependent_ids=dependent_ids,
                         type=type, description=description,
                         metadata=metadata)
        self.type = "Experience"
        self.owner = owner  # The owner of the experience
        self.reliable = reliable  # Whether the experience can be transformed into a skill.
                                  # We can use metadata to store scores and decide whether it is reliable.

    @property
    def is_reliable(self) -> bool:
        """
        Checks if the experience is reliable enough to be transformed into a skill.
        """
        return self.reliable

    def mark_reliable(self) -> None:
        """
        Marks the experience as reliable, allowing it to be transformed into a skill.
        """
        self.reliable = True

    def to_skill(self) -> Skill:
        """
        Converts this experience into a skill, but only if the experience is marked as reliable.
        If the experience is not reliable, raises a ValueError.

        Returns:
            Skill: A new Skill object that is based on the actions from this experience.
        """
        if not self.reliable:
            raise ValueError(f"Experience {self.id} cannot be transformed into a skill because it is not marked as reliable.")

        # Create and return a new Skill instance
        return Skill(
            name=self.name,
            steps=self.steps,  # Use the same steps (actions) from the experience
            id=self.id,
            dependent_ids=self.dependent_ids,
            type="Skill",
            description=f"Skill derived from Experience: {self.description}", 
            metadata=self.metadata
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Returns a dictionary representation of the object.
        """
        experience_dict = super().to_dict()
        experience_dict.update({
            "owner": self.owner,
            "reliable": self.reliable,
        })
        return experience_dict
is_reliable: bool property

Checks if the experience is reliable enough to be transformed into a skill.

__init__(name, steps, owner=None, reliable=False, id=None, dependent_ids=None, type=None, description=None, metadata=None)

Initializes an Experience by extending Procedure.

Source code in src/aeiva/action/experience.py
def __init__(self, name: str, steps: List[Union['Experience', Action]],
             owner: Optional[str] = None, reliable: Optional[bool] = False,
             id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,
             type: Optional[str] = None, description: Optional[str] = None,
             metadata: Optional[Dict[str, Any]] = None):
    """
    Initializes an Experience by extending Procedure.
    """
    super().__init__(name=name, steps=steps,
                     id=id, dependent_ids=dependent_ids,
                     type=type, description=description,
                     metadata=metadata)
    self.type = "Experience"
    self.owner = owner  # The owner of the experience
    self.reliable = reliable  # Whether the experience can be transformed into a skill. 
mark_reliable()

Marks the experience as reliable, allowing it to be transformed into a skill.

Source code in src/aeiva/action/experience.py
def mark_reliable(self) -> None:
    """
    Marks the experience as reliable, allowing it to be transformed into a skill.
    """
    self.reliable = True
to_dict()

Returns a dictionary representation of the object.

Source code in src/aeiva/action/experience.py
def to_dict(self) -> Dict[str, Any]:
    """
    Returns a dictionary representation of the object.
    """
    experience_dict = super().to_dict()
    experience_dict.update({
        "owner": self.owner,
        "reliable": self.reliable,
    })
    return experience_dict
to_skill()

Converts this experience into a skill, but only if the experience is marked as reliable. If the experience is not reliable, raises a ValueError.

Returns:

Skill: A new Skill object that is based on the actions from this experience.

Source code in src/aeiva/action/experience.py
def to_skill(self) -> Skill:
    """
    Converts this experience into a skill, but only if the experience is marked as reliable.
    If the experience is not reliable, raises a ValueError.

    Returns:
        Skill: A new Skill object that is based on the actions from this experience.
    """
    if not self.reliable:
        raise ValueError(f"Experience {self.id} cannot be transformed into a skill because it is not marked as reliable.")

    # Create and return a new Skill instance
    return Skill(
        name=self.name,
        steps=self.steps,  # Use the same steps (actions) from the experience
        id=self.id,
        dependent_ids=self.dependent_ids,
        type="Skill",
        description=f"Skill derived from Experience: {self.description}", 
        metadata=self.metadata
    )
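
A sketch of the reliability gate (illustrative only; the "echo" tool name is hypothetical):

from aeiva.action.action import Action
from aeiva.action.experience import Experience

exp = Experience(name="greet",
                 steps=[Action(name="echo", params={"text": "hi"}, id="a1")],
                 owner="alice", id="e1")

try:
    exp.to_skill()  # raises ValueError: not yet marked as reliable
except ValueError as err:
    print(err)

exp.mark_reliable()
skill = exp.to_skill()  # now succeeds
print(skill.type, exp.is_reliable)  # Skill True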

plan

Plan

Bases: Procedure

Represents a plan, which is a structured roadmap for achieving a goal by executing tasks and sub-plans. Inherits common functionality from Procedure.

Source code in src/aeiva/action/plan.py
class Plan(Procedure):
    """
    Represents a plan, which is a structured roadmap for achieving a goal by executing tasks and sub-plans.
    Inherits common functionality from Procedure.
    """

    def __init__(self, name: str, steps: List[Union['Plan', Task]],
                 id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,
                 type: Optional[str] = None, description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        """
        Initializes a Plan by extending Procedure.
        """
        super().__init__(name=name, steps=steps,
                         id=id, dependent_ids=dependent_ids,
                         type=type, description=description,
                         metadata=metadata)
        self.type = "Plan"
__init__(name, steps, id=None, dependent_ids=None, type=None, description=None, metadata=None)

Initializes a Plan by extending Procedure.

Source code in src/aeiva/action/plan.py
def __init__(self, name: str, steps: List[Union['Plan', Task]],
             id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,
             type: Optional[str] = None, description: Optional[str] = None,
             metadata: Optional[Dict[str, Any]] = None):
    """
    Initializes a Plan by extending Procedure.
    """
    super().__init__(name=name, steps=steps,
                     id=id, dependent_ids=dependent_ids,
                     type=type, description=description,
                     metadata=metadata)
    self.type = "Plan"

procedure

Procedure

Abstract base class for composite structures like Plan and Skill. Contains shared attributes and methods for organizing and managing steps (e.g., tasks, sub-procedures) in a directed acyclic graph (DAG).

Source code in src/aeiva/action/procedure.py
class Procedure:
    """
    Abstract base class for composite structures like Plan and Skill.
    Contains shared attributes and methods for organizing and managing steps (e.g., tasks, sub-procedures) 
    in a directed acyclic graph (DAG).
    """

    def __init__(self, name: str, steps: List[Union['Procedure', Step]],
                 id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,
                 type: Optional[str] = None, description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None,
                 *args, **kwargs):
        self.name = name
        self.steps = steps
        self.id = id
        self.dependent_ids = dependent_ids or []
        self.type = type
        self.description = description
        self.metadata = metadata or {}

        self.graph = nx.DiGraph()
        self.step_map = {step.id: step for step in steps}
        self.status = Status.NOT_EXECUTED

        # Add all steps as nodes in the graph
        for step in steps:
            self.graph.add_node(step)

        # Handle dependencies for steps
        for step in steps:
            for dep_id in step.dependent_ids:
                if dep_id in self.step_map:
                    self.graph.add_edge(self.step_map[dep_id], step)
                else:
                    raise ValueError(f"Dependency {dep_id} not found for step {step.id}.")

    def get_topological_sort(self):
        """
        Returns the steps in topologically sorted order based on the dependency graph.
        Ensures that all prerequisite steps are executed before the dependent ones.
        """
        try:
            return list(nx.topological_sort(self.graph))
        except nx.NetworkXUnfeasible:
            raise ValueError("The dependency graph contains cycles, which is not allowed in a procedure.")

    def reset(self) -> None:
        """
        Resets the status of the procedure and all its steps.
        """
        self.status = Status.NOT_EXECUTED
        for step in self.steps:
            step.reset()

    def start(self) -> None:
        """
        Marks the procedure as in progress. Raises an error if it's already in progress or finished.
        """
        if self.status != Status.NOT_EXECUTED:
            raise ValueError(f"{self.type} {self.id} has already been started or finished.")
        self.status = Status.EXECUTING

    def end(self, success: bool) -> None:
        """
        Marks the procedure as completed. Raises an error if it hasn't started yet.
        """
        if self.status != Status.EXECUTING:
            raise ValueError(f"Cannot finish {self.type} that hasn't started.")
        self.status = Status.SUCCESS if success else Status.FAIL

    @property
    def is_successful(self) -> bool:
        return all(step.is_successful for step in self.steps)

    @property
    def is_failed(self) -> bool:
        return any(step.is_failed for step in self.steps)

    @property
    def is_in_progress(self) -> bool:
        return any(step.is_in_progress for step in self.steps)

    @property
    def is_not_started(self) -> bool:
        return all(step.is_not_started for step in self.steps)

    @property
    def is_finished(self) -> bool:
        return all(step.is_finished for step in self.steps)

    def visualize(self, save_path: Optional[str] = None):
        """
        Visualizes the procedure's structure using networkx and matplotlib.
        """
        pos = nx.spring_layout(self.graph)  # Layout for the graph
        labels = {node: f"{node.id} ({node.description}, {node.status})" for node in self.graph.nodes()}

        # Draw the graph with labels
        nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)

        plt.title(f"{self.type} {self.description} Visualization")
        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()

    def to_dict(self) -> Dict[str, Any]:
        return {
            "name": self.name,
            "steps": [step.to_dict() for step in self.steps],
            "id": self.id,
            "dependent_ids": self.dependent_ids,
            "type": self.type,
            "description": self.description,
            "metadata": self.metadata,
            "status": self.status
        }
end(success)

Marks the procedure as completed. Raises an error if it hasn't started yet.

Source code in src/aeiva/action/procedure.py
def end(self, success: bool) -> None:
    """
    Marks the procedure as completed. Raises an error if it hasn't started yet.
    """
    if self.status != Status.EXECUTING:
        raise ValueError(f"Cannot finish {self.type} that hasn't started.")
    self.status = Status.SUCCESS if success else Status.FAIL
get_topological_sort()

Returns the steps in topologically sorted order based on the dependency graph. Ensures that all prerequisite steps are executed before the dependent ones.

Source code in src/aeiva/action/procedure.py
def get_topological_sort(self):
    """
    Returns the steps in topologically sorted order based on the dependency graph.
    Ensures that all prerequisite steps are executed before the dependent ones.
    """
    try:
        return list(nx.topological_sort(self.graph))
    except nx.NetworkXUnfeasible:
        raise ValueError("The dependency graph contains cycles, which is not allowed in a procedure.")
reset()

Resets the status of the procedure and all its steps.

Source code in src/aeiva/action/procedure.py
def reset(self) -> None:
    """
    Resets the status of the procedure and all its steps.
    """
    self.status = Status.NOT_EXECUTED
    for step in self.steps:
        step.reset()
start()

Marks the procedure as in progress. Raises an error if it's already in progress or finished.

Source code in src/aeiva/action/procedure.py
def start(self) -> None:
    """
    Marks the procedure as in progress. Raises an error if it's already in progress or finished.
    """
    if self.status != Status.NOT_EXECUTED:
        raise ValueError(f"{self.type} {self.id} has already been started or finished.")
    self.status = Status.EXECUTING
visualize(save_path=None)

Visualizes the procedure's structure using networkx and matplotlib.

Source code in src/aeiva/action/procedure.py
def visualize(self, save_path: Optional[str] = None):
    """
    Visualizes the procedure's structure using networkx and matplotlib.
    """
    pos = nx.spring_layout(self.graph)  # Layout for the graph
    labels = {node: f"{node.id} ({node.description}, {node.status})" for node in self.graph.nodes()}

    # Draw the graph with labels
    nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)

    plt.title(f"{self.type} {self.description} Visualization")
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()
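
A sketch exercising the shared DAG machinery through the concrete Plan subclass (step names are hypothetical):

from aeiva.action.plan import Plan
from aeiva.action.task import Task

# A small DAG: c depends on both a and b.
a = Task(name="a", id="a")
b = Task(name="b", id="b")
c = Task(name="c", id="c", dependent_ids=["a", "b"])

proc = Plan(name="demo", steps=[a, b, c], id="p1", description="DAG demo")
order = [s.id for s in proc.get_topological_sort()]
assert order.index("c") > order.index("a") and order.index("c") > order.index("b")

print(proc.is_not_started)  # True: no step has run yet
# proc.visualize("dag.png")  # renders the dependency graph via networkx/matplotlib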

skill

Skill

Bases: Procedure

Represents a skill, which is a structured roadmap for executing actions. Skills are composed of actions and can be executed. Inherits common functionality from Procedure.

Source code in src/aeiva/action/skill.py
class Skill(Procedure):
    """
    Represents a skill, which is a structured roadmap for executing actions.
    Skills are composed of actions and can be executed.
    Inherits common functionality from Procedure.
    """

    def __init__(self, name: str, steps: List[Union['Skill', Action]],
                 id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,
                 type: Optional[str] = None, description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        """
        Initializes a Skill by extending Procedure.
        """
        super().__init__(name=name, steps=steps,
                         id=id, dependent_ids=dependent_ids,
                         type=type, description=description,
                         metadata=metadata)
        self.type = "Skill"

    def get_topological_sort(self):
        """
        Returns the steps in topologically sorted order based on the dependency graph.
        Ensures that all prerequisite steps are executed before the dependent ones.
        """
        return list(nx.topological_sort(self.graph))

    async def execute(self):
        """
        Executes all actions in the skill based on the dependencies defined in the graph.
        This will run the actions asynchronously, respecting their dependencies.
        """
        self.start()

        # Perform topological sort right before execution
        sorted_steps = self.get_topological_sort()

        for step in sorted_steps:
            if isinstance(step, Action):
                print(f"Executing Action: {step.id} - {step.description}")
                await step.execute(step.params)  # Execute the action asynchronously
            elif isinstance(step, Skill):
                print(f"Executing Sub-Skill: {step.id}")
                await step.execute()  # If it's a sub-skill, execute the sub-skill

        self.end(success=self.is_successful)
__init__(name, steps, id=None, dependent_ids=None, type=None, description=None, metadata=None)

Initializes a Skill by extending Procedure.

Source code in src/aeiva/action/skill.py
def __init__(self, name: str, steps: List[Union['Skill', Action]],
             id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,
             type: Optional[str] = None, description: Optional[str] = None,
             metadata: Optional[Dict[str, Any]] = None):
    """
    Initializes a Skill by extending Procedure.
    """
    super().__init__(name=name, steps=steps,
                     id=id, dependent_ids=dependent_ids,
                     type=type, description=description,
                     metadata=metadata)
    self.type = "Skill"
execute() async

Executes all actions in the skill based on the dependencies defined in the graph. This will run the actions asynchronously, respecting their dependencies.

Source code in src/aeiva/action/skill.py
async def execute(self):
    """
    Executes all actions in the skill based on the dependencies defined in the graph.
    This will run the actions asynchronously, respecting their dependencies.
    """
    self.start()

    # Perform topological sort right before execution
    sorted_steps = self.get_topological_sort()

    for step in sorted_steps:
        if isinstance(step, Action):
            print(f"Executing Action: {step.id} - {step.description}")
            await step.execute(step.params)  # Execute the action asynchronously
        elif isinstance(step, Skill):
            print(f"Executing Sub-Skill: {step.id}")
            await step.execute()  # If it's a sub-skill, execute the sub-skill

    self.end(success=self.is_successful)
get_topological_sort()

Returns the steps in topologically sorted order based on the dependency graph. Ensures that all prerequisite steps are executed before the dependent ones.

Source code in src/aeiva/action/skill.py
def get_topological_sort(self):
    """
    Returns the steps in topologically sorted order based on the dependency graph.
    Ensures that all prerequisite steps are executed before the dependent ones.
    """
    return list(nx.topological_sort(self.graph))
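
An execution sketch (illustrative only): the tool names "fetch" and "process" are hypothetical, so running this requires tools that Tool(name) can actually resolve:

import asyncio

from aeiva.action.action import Action
from aeiva.action.skill import Skill

async def main():
    step1 = Action(name="fetch", params={}, id="s1")    # hypothetical tool
    step2 = Action(name="process", params={}, id="s2",  # hypothetical tool
                   dependent_ids=["s1"])                # runs after s1

    skill = Skill(name="pipeline", steps=[step1, step2], id="sk1")
    await skill.execute()  # runs the actions in topological order
    print(skill.is_successful)

asyncio.run(main())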

status

Status

A class to hold status constants.

Source code in src/aeiva/action/status.py
class Status:
    """
    A class to hold status constants.
    """
    NOT_EXECUTED = "Not Executed"
    EXECUTING = "Executing"
    SUCCESS = "Success"
    FAIL = "Fail"

step

Step

Abstract base class for atomic units like Task and Action. Contains shared attributes and methods for managing their execution and dependencies.

Source code in src/aeiva/action/step.py
class Step:
    """
    Abstract base class for atomic units like Task and Action.
    Contains shared attributes and methods for managing their execution and dependencies.
    """

    def __init__(self, name: str, params: Dict[str, Any] = None,
                 id: Optional[str] = None, dependent_ids: Optional[List[str]] = None, 
                 type: Optional[str] = None, description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None,
                 *args, **kwargs):
        self.name = name  # The name of the step. It can be a task/action/tool/api/function name
        self.params = params  # The parameters for this step. It can be a task/action/tool/api/function's params
        self.id = id  # Unique identifier for the step
        self.dependent_ids = dependent_ids or []  # List of IDs of steps that must be completed before this one
        self.type = type  # The type of this step, e.g., task or action
        self.description = description  # A description for this step
        self.metadata = metadata or {}  # Optional metadata (e.g., id, type, description, priority, etc.)
        self.status = Status.NOT_EXECUTED  # Initial status

    def reset(self) -> None:
        """
        Resets the step status, making it ready for re-execution.
        """
        self.status = Status.NOT_EXECUTED

    def start(self) -> None:
        """
        Marks the step as in progress. Raises an error if the step is already started or finished.
        """
        if self.status != Status.NOT_EXECUTED:
            raise ValueError(f"{self.type} {self.description} {self.id} has already been started or finished.")
        self.status = Status.EXECUTING

    def end(self, success: bool) -> None:
        """
        Marks the step as finished and indicates whether it was successful.
        Can only be called if the step is in progress.
        """
        if self.status != Status.EXECUTING:
            raise ValueError(f"Cannot finish a {self.type} that hasn't started.")
        self.status = Status.SUCCESS if success else Status.FAIL

    @property
    def is_successful(self) -> bool:
        """
        Returns True if the step was completed successfully.
        """
        return self.status == Status.SUCCESS

    @property
    def is_failed(self) -> bool:
        """
        Returns True if the step has finished but failed.
        """
        return self.status == Status.FAIL

    @property
    def is_in_progress(self) -> bool:
        """
        Returns True if the step is in progress (executing but not finished).
        """
        return self.status == Status.EXECUTING

    @property
    def is_not_started(self) -> bool:
        """
        Returns True if the step has not started yet.
        """
        return self.status == Status.NOT_EXECUTED

    @property
    def is_finished(self) -> bool:
        """
        Returns True if the step has finished execution, either successfully or failed.
        """
        return self.status == Status.SUCCESS or self.status == Status.FAIL

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the step into a dictionary representation.
        """
        return {
            "name": self.name,
            "params": self.params,
            "id": self.id,
            "dependent_ids": self.dependent_ids,
            "type": self.type,
            "description": self.description,
            "status": self.status,
            "metadata": self.metadata
        }
is_failed: bool property

Returns True if the step has finished but failed.

is_finished: bool property

Returns True if the step has finished execution, either successfully or failed.

is_in_progress: bool property

Returns True if the step is in progress (executing but not finished).

is_not_started: bool property

Returns True if the step has not started yet.

is_successful: bool property

Returns True if the step was completed successfully.

end(success)

Marks the step as finished and indicates whether it was successful. Can only be called if the step is in progress.

Source code in src/aeiva/action/step.py
def end(self, success: bool) -> None:
    """
    Marks the step as finished and indicates whether it was successful.
    Can only be called if the step is in progress.
    """
    if self.status != Status.EXECUTING:
        raise ValueError(f"Cannot finish a {self.type} that hasn't started.")
    self.status = Status.SUCCESS if success else Status.FAIL
reset()

Resets the step status, making it ready for re-execution.

Source code in src/aeiva/action/step.py
def reset(self) -> None:
    """
    Resets the step status, making it ready for re-execution.
    """
    self.status = Status.NOT_EXECUTED
start()

Marks the step as in progress. Raises an error if the step is already started or finished.

Source code in src/aeiva/action/step.py
def start(self) -> None:
    """
    Marks the step as in progress. Raises an error if the step is already started or finished.
    """
    if self.status != Status.NOT_EXECUTED:
        raise ValueError(f"{self.type} {self.description} {self.id} has already been started or finished.")
    self.status = Status.EXECUTING
to_dict()

Converts the step into a dictionary representation.

Source code in src/aeiva/action/step.py
def to_dict(self) -> Dict[str, Any]:
    """
    Converts the step into a dictionary representation.
    """
    return {
        "name": self.name,
        "params": self.params,
        "id": self.id,
        "dependent_ids": self.dependent_ids,
        "type": self.type,
        "description": self.description,
        "status": self.status,
        "metadata": self.metadata
    }
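
A lifecycle sketch built only from the methods above:

from aeiva.action.step import Step

step = Step(name="demo", id="s1", description="lifecycle demo")
print(step.is_not_started)  # True

step.start()                # "Not Executed" -> "Executing"
print(step.is_in_progress)  # True

step.end(success=True)      # "Executing" -> "Success"
print(step.is_successful, step.is_finished)  # True True

step.reset()                # back to "Not Executed", ready for re-execution
print(step.to_dict()["status"])  # Not Executed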

task

Task

Bases: Step

Represents the fundamental unit of work, extending from the Step class. Inherits shared attributes and methods from Step and adds task-specific functionality.

Source code in src/aeiva/action/task.py
class Task(Step):
    """
    Represents the fundamental unit of work, extending from the Step class.
    Inherits shared attributes and methods from Step and adds task-specific functionality.
    """

    def __init__(self, name: str, params: Dict[str, Any] = None,
                 id: Optional[str] = None, dependent_ids: Optional[List[str]] = None, 
                 type: Optional[str] = None, description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        super().__init__(name=name, params=params,
                         id=id, dependent_ids=dependent_ids,
                         type=type, description=description,
                         metadata=metadata)
        self.type = "Task"

    def show(self) -> None:
        print("---- Task Information ----")
        pprint(self.to_dict(), sort_dicts=False)
        print("---- End of Task ----")

agent

agent

Agent

Represents the agent that integrates perception, cognition, and action systems.

Source code in src/aeiva/agent/agent.py
class Agent:
    """
    Represents the agent that integrates perception, cognition, and action systems.
    """
    def __init__(self, config: Dict):
        self.config_dict = config
        self.config = None
        self.event_bus = EventBus()
        self.perception_system = None
        self.cognition_system = None
        self.action_system = None

    def setup(self) -> None:
        """
        Set up all systems.
        """
        perception_config = self.config_dict.get('perception_config', {})
        cognition_config = self.config_dict  # NOTE: we didn't define a cognition config class yet.
        action_config = self.config_dict.get('action_config', {})

        self.perception_system = PerceptionSystem(perception_config, self.event_bus)
        self.cognition_system = CognitionSystem(cognition_config)
        self.action_system = ActionSystem(action_config)

        self.perception_system.setup()
        self.cognition_system.setup()
        self.action_system.setup()

    async def run(self) -> None:
        """
        Run the agent by connecting perception, cognition, and action systems using the event bus.
        """
        # Start the event bus within the running event loop
        self.event_bus.start()
        # Assign the current running loop to the EventBus
        self.event_bus.loop = asyncio.get_running_loop()
        # Set up event handlers
        self.setup_event_handlers()
        # Start the perception system
        await self.perception_system.start()

        # Keep the event loop running until interrupted
        try:
            while True:
                await asyncio.sleep(1)
        except KeyboardInterrupt:
            # Handle graceful shutdown
            self.perception_system.stop()
            await self.event_bus.wait_until_all_events_processed()
            self.event_bus.stop()
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # logger.error(f"Unexpected error in agent run loop: {e}")
            print(f"Unexpected error in agent run loop: {e}", flush=True)
            await self.perception_system.stop()
            await self.event_bus.wait_until_all_events_processed()
            self.event_bus.stop()

    async def process_input(self, input_text: str) -> str:
        """
        Process input text and return the agent's response.
        """
        stream = self.config_dict.get("llm_gateway_config").get("llm_stream")
        use_async = self.config_dict.get("llm_gateway_config").get("llm_use_async")
        stimuli = Stimuli(signals=[Signal(data=input_text, modularity='text')])
        output = ""
        try:
            response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)
            async for chunk in response_gen:
                if isinstance(chunk, str):
                    # For streaming chunks
                    output += chunk
                elif isinstance(chunk, Thought) or isinstance(chunk, Plan):
                    # For non-streaming responses
                    output += chunk.content
        except Exception as e:
            logger.error(f"Error in response: {e}")
        return output

    def setup_event_handlers(self) -> None:
        """
        Set up event handlers for perception, cognition, and action events.
        """

        @self.event_bus.on('perception.stimuli')
        async def handle_stimuli(event: Event):
            # print("handle_stimuli called", flush=True)
            user_input = event.payload
            stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])
            #print(f"Received stimuli: {stimuli}", flush=True)
            # Process stimuli through cognition system
            #stimuli = [{"role": "user", "content": stimuli}]

            stream = self.config_dict.get("llm_gateway_config").get("llm_stream")
            use_async = self.config_dict.get("llm_gateway_config").get("llm_use_async")
            sys.stdout.write("\r\033[K")  # Return to the start of the line and clear it
            print("Response: ", end='', flush=True)

            try:
                response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)
                async for chunk in response_gen:
                    if isinstance(chunk, str):
                        # For streaming chunks
                        print(f"{chunk}", end='', flush=True)
                    elif isinstance(chunk, Thought) or isinstance(chunk, Plan):
                        # For non-streaming responses
                        print(f"{chunk.content}", end='', flush=True)
            except Exception as e:
                logger.error(f"Error in response: {e}")

            print("\nYou: ", end='', flush=True)

            # # Determine if output is a Plan or Thought
            # if isinstance(output, Plan):  # TODO: change later
            #     print("Output is a Plan", flush=True)
            #     await self.event_bus.emit('action.plan', payload=output)
            # elif isinstance(output, Thought):
            #     print("Output is a Thought", flush=True)
            #     print(f"Agent Response: {output.content}", flush=True)
            # else:
            #     print("Unknown output from cognition system.", flush=True)

        @self.event_bus.on('action.plan')
        async def handle_plan(event: Event):
            print("handle_plan called", flush=True)
            plan = event.payload
            await self.action_system.execute(plan)

        @self.event_bus.on('perception.gradio')
        async def handle_gradio_input(event: Event):
            """
            Handle input from Gradio and emit response.gradio events.
            """
            user_input = event.payload
            stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])

            stream = self.config_dict.get("llm_gateway_config").get("llm_stream")
            use_async = self.config_dict.get("llm_gateway_config").get("llm_use_async")
            logger.info(f"Handling Gradio input: {user_input} | Stream: {stream}")
            try:
                response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)

                async for chunk in response_gen:
                    if isinstance(chunk, str):
                        # For streaming chunks
                        await self.event_bus.emit('response.gradio', payload=chunk)
                    elif isinstance(chunk, Thought) or isinstance(chunk, Plan):
                        # For non-streaming responses
                        await self.event_bus.emit('response.gradio', payload=chunk.content if hasattr(chunk, 'content') else str(chunk))

                if stream:
                    await self.event_bus.emit('response.gradio', payload="<END_OF_RESPONSE>")
            except Exception as e:
                logger.error(f"Error in streaming response: {e}")
                await self.event_bus.emit('response.gradio', payload="An error occurred during response generation.")
                if stream:
                    await self.event_bus.emit('response.gradio', payload="<END_OF_RESPONSE>")
process_input(input_text) async

Process input text and return the agent's response.

Source code in src/aeiva/agent/agent.py
async def process_input(self, input_text: str) -> str:
    """
    Process input text and return the agent's response.
    """
    stream = self.config_dict.get("llm_gateway_config").get("llm_stream")
    use_async = self.config_dict.get("llm_gateway_config").get("llm_use_async")
    stimuli = Stimuli(signals=[Signal(data=input_text, modularity='text')])
    output = ""
    try:
        response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)
        async for chunk in response_gen:
            if isinstance(chunk, str):
                # For streaming chunks
                output += chunk
            elif isinstance(chunk, Thought) or isinstance(chunk, Plan):
                # For non-streaming responses
                output += chunk.content
    except Exception as e:
        logger.error(f"Error in response: {e}")
    return output
run() async

Run the agent by connecting perception, cognition, and action systems using the event bus.

Source code in src/aeiva/agent/agent.py
async def run(self) -> None:
    """
    Run the agent by connecting perception, cognition, and action systems using the event bus.
    """
    # Start the event bus within the running event loop
    self.event_bus.start()
    # Assign the current running loop to the EventBus
    self.event_bus.loop = asyncio.get_running_loop()
    # Set up event handlers
    self.setup_event_handlers()
    # Start the perception system
    await self.perception_system.start()

    # Keep the event loop running until interrupted
    try:
        while True:
            await asyncio.sleep(1)
    except KeyboardInterrupt:
        # Handle graceful shutdown
        self.perception_system.stop()
        await self.event_bus.wait_until_all_events_processed()
        self.event_bus.stop()
    except asyncio.CancelledError:
        pass
    except Exception as e:
        # logger.error(f"Unexpected error in agent run loop: {e}")
        print(f"Unexpected error in agent run loop: {e}", flush=True)
        await self.perception_system.stop()
        await self.event_bus.wait_until_all_events_processed()
        self.event_bus.stop()
setup()

Set up all systems.

Source code in src/aeiva/agent/agent.py
def setup(self) -> None:
    """
    Set up all systems.
    """
    perception_config = self.config_dict.get('perception_config', {})
    cognition_config = self.config_dict  # NOTE: we didn't define a cognition config class yet.
    action_config = self.config_dict.get('action_config', {})

    self.perception_system = PerceptionSystem(perception_config, self.event_bus)
    self.cognition_system = CognitionSystem(cognition_config)
    self.action_system = ActionSystem(action_config)

    self.perception_system.setup()
    self.cognition_system.setup()
    self.action_system.setup()
setup_event_handlers()

Set up event handlers for perception, cognition, and action events.

Source code in src/aeiva/agent/agent.py
def setup_event_handlers(self) -> None:
    """
    Set up event handlers for perception, cognition, and action events.
    """

    @self.event_bus.on('perception.stimuli')
    async def handle_stimuli(event: Event):
        # print("handle_stimuli called", flush=True)
        user_input = event.payload
        stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])
        #print(f"Received stimuli: {stimuli}", flush=True)
        # Process stimuli through cognition system
        #stimuli = [{"role": "user", "content": stimuli}]

        stream = self.config_dict.get("llm_gateway_config").get("llm_stream")
        use_async = self.config_dict.get("llm_gateway_config").get("llm_use_async")
        sys.stdout.write("\r\033[K")  # Return to the start of the line and clear it
        print("Response: ", end='', flush=True)

        try:
            response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)
            async for chunk in response_gen:
                if isinstance(chunk, str):
                    # For streaming chunks
                    print(f"{chunk}", end='', flush=True)
                elif isinstance(chunk, Thought) or isinstance(chunk, Plan):
                    # For non-streaming responses
                    print(f"{chunk.content}", end='', flush=True)
        except Exception as e:
            logger.error(f"Error in response: {e}")

        print("\nYou: ", end='', flush=True)

        # # Determine if output is a Plan or Thought
        # if isinstance(output, Plan):  # TODO: change later
        #     print("Output is a Plan", flush=True)
        #     await self.event_bus.emit('action.plan', payload=output)
        # elif isinstance(output, Thought):
        #     print("Output is a Thought", flush=True)
        #     print(f"Agent Response: {output.content}", flush=True)
        # else:
        #     print("Unknown output from cognition system.", flush=True)

    @self.event_bus.on('action.plan')
    async def handle_plan(event: Event):
        print("handle_plan called", flush=True)
        plan = event.payload
        await self.action_system.execute(plan)

    @self.event_bus.on('perception.gradio')
    async def handle_gradio_input(event: Event):
        """
        Handle input from Gradio and emit response.gradio events.
        """
        user_input = event.payload
        stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])

        stream = self.config_dict.get("llm_gateway_config").get("llm_stream")
        use_async = self.config_dict.get("llm_gateway_config").get("llm_use_async")
        logger.info(f"Handling Gradio input: {user_input} | Stream: {stream}")
        try:
            response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)

            async for chunk in response_gen:
                if isinstance(chunk, str):
                    # For streaming chunks
                    await self.event_bus.emit('response.gradio', payload=chunk)
                elif isinstance(chunk, Thought) or isinstance(chunk, Plan):
                    # For non-streaming responses
                    await self.event_bus.emit('response.gradio', payload=chunk.content if hasattr(chunk, 'content') else str(chunk))

            if stream:
                await self.event_bus.emit('response.gradio', payload="<END_OF_RESPONSE>")
        except Exception as e:
            logger.error(f"Error in streaming response: {e}")
            await self.event_bus.emit('response.gradio', payload="An error occurred during response generation.")
            if stream:
                await self.event_bus.emit('response.gradio', payload="<END_OF_RESPONSE>")
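
A minimal sketch of driving the agent programmatically (illustrative only). The config below shows just the keys referenced in the source above; real deployments need valid perception, action, and LLM gateway settings:

import asyncio

from aeiva.agent.agent import Agent

# Placeholder config: only keys read by the code above are shown, and the
# empty sub-configs are assumptions that a real setup must fill in.
config = {
    "perception_config": {},
    "action_config": {"tools": []},
    "llm_gateway_config": {"llm_stream": False, "llm_use_async": True},
}

agent = Agent(config)
agent.setup()

response = asyncio.run(agent.process_input("Hello!"))
print(response)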

base_agent

BaseAgent

Bases: ABC

Abstract base class for autonomous agents with perception, cognition, and action capabilities.

Source code in src/aeiva/agent/base_agent.py
class BaseAgent(ABC):
    """
    Abstract base class for autonomous agents with perception, cognition, and action capabilities.
    """

    def __init__(self, config: Any):
        """
        Initialize the agent with configuration.

        Args:
            config (Any): Configuration settings for the agent.
        """
        self.config = config
        self.state = self.initialize_state()  # can be a dict that includes: id, profile, motivation, goal, task, plan, etc.
        self.stop_event = asyncio.Event()

        # Systems will be initialized in the setup method
        self.perception_system: PerceptionSystem = None
        self.cognition_system: CognitionSystem = None
        self.action_system: ActionSystem = None

    @abstractmethod
    def initialize_state(self) -> Any:
        """
        Initialize the agent's state.

        Returns:
            Any: The initial state of the agent.
        """
        pass

    @abstractmethod
    def setup(self) -> None:
        """
        Set up the agent's components (perception, cognition, action, etc.).
        Perform any asynchronous initialization if necessary.
        """
        pass

    @abstractmethod
    async def cycle(self) -> None:
        """
        Execute one cycle of perception, cognition, and action.
        This method should be overridden to define the agent's behavior per cycle.
        """
        pass

    async def run(self) -> None:
        """
        Run the agent, continuously executing cycles until stopped.
        """
        await self.setup()
        cycle_interval = self.config.get('cycle_interval', 1.0)
        while not self.stop_event.is_set():
            try:
                await self.cycle()
            except Exception as e:
                self.handle_error(e)
            await asyncio.sleep(cycle_interval)

    def stop(self) -> None:
        """
        Signal the agent to stop running.
        """
        self.stop_event.set()

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during cycle execution.

        Args:
            error (Exception): The exception that was raised.
        """
        # Implement your error handling logic here (e.g., logging)
        print(f"Error during agent cycle: {error}")
__init__(config)

Initialize the agent with configuration.

Parameters:

config (Any): Configuration settings for the agent. Required.
Source code in src/aeiva/agent/base_agent.py
def __init__(self, config: Any):
    """
    Initialize the agent with configuration.

    Args:
        config (Any): Configuration settings for the agent.
    """
    self.config = config
    self.state = self.initialize_state()  # can be a dict that includes: id, profile, motivation, goal, task, plan, etc.
    self.stop_event = asyncio.Event()

    # Systems will be initialized in the setup method
    self.perception_system: PerceptionSystem = None
    self.cognition_system: CognitionSystem = None
    self.action_system: ActionSystem = None
cycle() abstractmethod async

Execute one cycle of perception, cognition, and action. This method should be overridden to define the agent's behavior per cycle.

Source code in src/aeiva/agent/base_agent.py
@abstractmethod
async def cycle(self) -> None:
    """
    Execute one cycle of perception, cognition, and action.
    This method should be overridden to define the agent's behavior per cycle.
    """
    pass
handle_error(error)

Handle errors that occur during cycle execution.

Parameters:

error (Exception): The exception that was raised. Required.
Source code in src/aeiva/agent/base_agent.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during cycle execution.

    Args:
        error (Exception): The exception that was raised.
    """
    # Implement your error handling logic here (e.g., logging)
    print(f"Error during agent cycle: {error}")
initialize_state() abstractmethod

Initialize the agent's state.

Returns:

Any: The initial state of the agent.

Source code in src/aeiva/agent/base_agent.py
@abstractmethod
def initialize_state(self) -> Any:
    """
    Initialize the agent's state.

    Returns:
        Any: The initial state of the agent.
    """
    pass
run() async

Run the agent, continuously executing cycles until stopped.

Source code in src/aeiva/agent/base_agent.py
async def run(self) -> None:
    """
    Run the agent, continuously executing cycles until stopped.
    """
    await self.setup()
    cycle_interval = self.config.get('cycle_interval', 1.0)
    while not self.stop_event.is_set():
        try:
            await self.cycle()
        except Exception as e:
            self.handle_error(e)
        await asyncio.sleep(cycle_interval)
setup() abstractmethod

Set up the agent's components (perception, cognition, action, etc.). Perform any asynchronous initialization if necessary.

Source code in src/aeiva/agent/base_agent.py
@abstractmethod
def setup(self) -> None:
    """
    Set up the agent's components (perception, cognition, action, etc.).
    Perform any asynchronous initialization if necessary.
    """
    pass
stop()

Signal the agent to stop running.

Source code in src/aeiva/agent/base_agent.py
def stop(self) -> None:
    """
    Signal the agent to stop running.
    """
    self.stop_event.set()
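
A minimal concrete subclass sketch (illustrative only). Note that run() awaits setup(), so this sketch implements setup() as a coroutine:

import asyncio

from aeiva.agent.base_agent import BaseAgent

class CountingAgent(BaseAgent):
    """Toy agent that stops itself after three cycles."""

    def initialize_state(self) -> dict:
        return {"ticks": 0}

    async def setup(self) -> None:
        pass  # a real agent would build its perception/cognition/action systems here

    async def cycle(self) -> None:
        self.state["ticks"] += 1
        if self.state["ticks"] >= 3:
            self.stop()

agent = CountingAgent({"cycle_interval": 0.1})  # config must support .get()
asyncio.run(agent.run())
print(agent.state)  # {'ticks': 3}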

cognition

brain

brain

Brain

Bases: ABC

Abstract base class representing the cognitive processing unit.

The Brain is responsible for processing input stimuli to generate cognitive states that the CognitionSystem will translate into actions.

Attributes:

config (Any): Configuration settings for the Brain.

state (Any): The internal state of the Brain.

Source code in src/aeiva/cognition/brain/brain.py
class Brain(ABC):
    """
    Abstract base class representing the cognitive processing unit.

    The Brain is responsible for processing input stimuli to generate cognitive states
    that the CognitionSystem will translate into actions.

    Attributes:
        config (Any): Configuration settings for the Brain.
        state (Any): The internal state of the Brain.
    """

    def __init__(self, config: Any):
        """
        Initialize the Brain with the provided configuration.

        Args:
            config (Any): Configuration settings for the Brain.
        """
        self.config = config
        self.state = self.init_state()

    @abstractmethod
    def init_state(self) -> Any:
        """
        Initialize the internal state of the Brain.

        This method should set up the initial state required for the Brain's operations.

        Returns:
            Any: The initial state of the Brain.
        """
        pass

    @abstractmethod
    def setup(self) -> None:
        """
        Set up the Brain's components.

        This method should initialize any necessary components or resources
        based on the provided configuration.

        Raises:
            ConfigurationError: If the configuration is invalid or incomplete.
        """
        pass

    @abstractmethod
    async def think(self, stimuli: Any, *args, **kwargs) -> Any:
        """
        Asynchronously process input stimuli to update the cognitive state.

        Args:
            stimuli (Any): The input stimuli to process.

        Returns:
            Any: The updated cognitive state.

        Raises:
            ProcessingError: If processing the stimuli fails.
        """
        pass

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during cognitive processing.

        This method can be overridden to implement custom error handling logic.

        Args:
            error (Exception): The exception that was raised.
        """
        # Default error handling: log the error
        print(f"Brain encountered an error: {error}")
__init__(config)

Initialize the Brain with the provided configuration.

Parameters:

    config (Any, required): Configuration settings for the Brain.
Source code in src/aeiva/cognition/brain/brain.py
def __init__(self, config: Any):
    """
    Initialize the Brain with the provided configuration.

    Args:
        config (Any): Configuration settings for the Brain.
    """
    self.config = config
    self.state = self.init_state()
handle_error(error)

Handle errors that occur during cognitive processing.

This method can be overridden to implement custom error handling logic.

Parameters:

    error (Exception, required): The exception that was raised.
Source code in src/aeiva/cognition/brain/brain.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during cognitive processing.

    This method can be overridden to implement custom error handling logic.

    Args:
        error (Exception): The exception that was raised.
    """
    # Default error handling: log the error
    print(f"Brain encountered an error: {error}")
init_state() abstractmethod

Initialize the internal state of the Brain.

This method should set up the initial state required for the Brain's operations.

Returns:

    Any: The initial state of the Brain.

Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod
def init_state(self) -> Any:
    """
    Initialize the internal state of the Brain.

    This method should set up the initial state required for the Brain's operations.

    Returns:
        Any: The initial state of the Brain.
    """
    pass
setup() abstractmethod

Set up the Brain's components.

This method should initialize any necessary components or resources based on the provided configuration.

Raises:

    ConfigurationError: If the configuration is invalid or incomplete.

Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod
def setup(self) -> None:
    """
    Set up the Brain's components.

    This method should initialize any necessary components or resources
    based on the provided configuration.

    Raises:
        ConfigurationError: If the configuration is invalid or incomplete.
    """
    pass
think(stimuli, *args, **kwargs) abstractmethod async

Asynchronously process input stimuli to update the cognitive state.

Parameters:

    stimuli (Any, required): The input stimuli to process.

Returns:

    Any: The updated cognitive state.

Raises:

    ProcessingError: If processing the stimuli fails.

Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod
async def think(self, stimuli: Any, *args, **kwargs) -> Any:
    """
    Asynchronously process input stimuli to update the cognitive state.

    Args:
        stimuli (Any): The input stimuli to process.

    Returns:
        Any: The updated cognitive state.

    Raises:
        ProcessingError: If processing the stimuli fails.
    """
    pass
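
Before turning to the concrete LLM-backed brain below, a minimal sketch of what a Brain subclass must provide. The class and its behavior are illustrative; only the three abstract methods come from the interface above.

from typing import Any

from aeiva.cognition.brain.brain import Brain  # module path as shown in the listings above


class UppercaseBrain(Brain):
    """Toy Brain: 'thinks' by upper-casing text stimuli."""

    def init_state(self) -> Any:
        return {"last_thought": None}

    def setup(self) -> None:
        pass  # nothing to configure in this sketch

    async def think(self, stimuli: Any, *args, **kwargs) -> Any:
        thought = str(stimuli).upper()
        self.state["last_thought"] = thought  # state exists because __init__ calls init_state()
        return thought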

llm_brain

LLMBrain

Bases: Brain

Concrete implementation of the Brain, using an LLM to process stimuli and generate cognitive states.

This brain uses the LLMClient to communicate with a language model to process input stimuli and produce outputs.

Source code in src/aeiva/cognition/brain/llm_brain.py
class LLMBrain(Brain):
    """
    Concrete implementation of the Brain, using an LLM to process stimuli
    and generate cognitive states.

    This brain uses the LLMClient to communicate with a language model to
    process input stimuli and produce outputs.
    """

    def __init__(self, config: Dict):
        """
        Initialize the LLMBrain with the provided LLM configuration.

        Args:
            config (Dict): Configuration settings for the LLMBrain.
        """
        super().__init__(config)
        self.config_dict = config
        self.config = None
        self.llm_client = None

    def init_state(self) -> Any:
        """
        Initialize the internal state of the Brain.

        The state can track the ongoing conversation or task context.

        Returns:
            dict: Initial empty state.
        """
        return {"conversation": [], "cognitive_state": None}

    def setup(self) -> None:
        """
        Set up the Brain's components.

        For the LLMBrain, this might involve validating the LLM configuration
        and ensuring that all necessary resources are in place.
        """
        llm_conf_dict = self.config_dict.get('llm_gateway_config', {})
        self.config = LLMGatewayConfig(
            llm_api_key=llm_conf_dict.get('llm_api_key'),
            llm_model_name=llm_conf_dict.get('llm_model_name', 'gpt-4o'),
            llm_temperature=llm_conf_dict.get('llm_temperature', 0.7),
            llm_max_output_tokens=llm_conf_dict.get('llm_max_output_tokens', 10000),
            llm_use_async=llm_conf_dict.get('llm_use_async', False),
            llm_stream=llm_conf_dict.get('llm_stream', False)
        )
        self.llm_client = LLMClient(self.config)

        system_prompt = llm_conf_dict.get('llm_system_prompt', None)
        if system_prompt is not None:  # TODO: only add system prompt for llms that support it.
                self.state["conversation"] += [{ "role": "system", "content": system_prompt }]

        print("LLMBrain setup complete.")

    async def think(
            self,
            stimuli: Any,
            tools: List[Dict[str, Any]] = None,
            stream: bool = False,
            use_async: bool = False
            ) -> AsyncGenerator[str, None]:
        """
        Asynchronously process input stimuli to update the cognitive state.

        Args:
            stimuli (Any): The input stimuli to process.
            tools (List[Dict[str, Any]]): Optional tool schemas for function calls. Default is None.
            stream (bool): Whether to use streaming mode. Default is False.
            use_async (bool): Whether to call the LLM client asynchronously. Default is False.

        Yields:
            str: The response, as incremental chunks in streaming mode or as one full string otherwise.
        """
        try:
            # Stimuli must be a list of messages (conversation context).
            if not isinstance(stimuli, list):
                raise ValueError("Stimuli must be a list of messages.")

            self.state["conversation"] += stimuli  # NOTE: keep the history so the LLM remembers context

            if not use_async:  # NOTE: stream mode only works when use_async is True
                response = self.llm_client(self.state["conversation"], tools=tools, stream=stream)  # NOTE: the client updates the conversation
                self.state["cognitive_state"] = response
                yield response
            elif stream:
                # Stream mode: yield each delta and collect the full response.
                response = ""
                async for delta in self.llm_client(self.state["conversation"], tools=tools, stream=stream):
                    response += delta
                    yield delta
                self.state["cognitive_state"] = response
            else:
                response = await self.llm_client(self.state["conversation"], tools=tools, stream=stream)
                self.state["cognitive_state"] = response
                yield response

        except Exception as e:
            self.handle_error(e)
            raise

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during cognitive processing.

        Args:
            error (Exception): The exception that was raised.
        """
        super().handle_error(error)
        # Custom error handling logic for LLM-related issues
        print(f"LLMBrain encountered an error: {error}")
__init__(config)

Initialize the LLMBrain with the provided LLM configuration.

Parameters:

    config (Dict, required): Configuration settings for the LLMBrain.
Source code in src/aeiva/cognition/brain/llm_brain.py
def __init__(self, config: Dict):
    """
    Initialize the LLMBrain with the provided LLM configuration.

    Args:
        config (Dict): Configuration settings for the LLMBrain.
    """
    super().__init__(config)
    self.config_dict = config
    self.config = None
    self.llm_client = None
handle_error(error)

Handle errors that occur during cognitive processing.

Parameters:

    error (Exception, required): The exception that was raised.
Source code in src/aeiva/cognition/brain/llm_brain.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during cognitive processing.

    Args:
        error (Exception): The exception that was raised.
    """
    super().handle_error(error)
    # Custom error handling logic for LLM-related issues
    print(f"LLMBrain encountered an error: {error}")
init_state()

Initialize the internal state of the Brain.

The state can track the ongoing conversation or task context.

Returns:

    dict: Initial empty state.

Source code in src/aeiva/cognition/brain/llm_brain.py
def init_state(self) -> Any:
    """
    Initialize the internal state of the Brain.

    The state can track the ongoing conversation or task context.

    Returns:
        dict: Initial empty state.
    """
    return {"conversation": [], "cognitive_state": None}
setup()

Set up the Brain's components.

For the LLMBrain, this might involve validating the LLM configuration and ensuring that all necessary resources are in place.

Source code in src/aeiva/cognition/brain/llm_brain.py
def setup(self) -> None:
    """
    Set up the Brain's components.

    For the LLMBrain, this might involve validating the LLM configuration
    and ensuring that all necessary resources are in place.
    """
    llm_conf_dict = self.config_dict.get('llm_gateway_config', {})
    self.config = LLMGatewayConfig(
        llm_api_key=llm_conf_dict.get('llm_api_key'),
        llm_model_name=llm_conf_dict.get('llm_model_name', 'gpt-4o'),
        llm_temperature=llm_conf_dict.get('llm_temperature', 0.7),
        llm_max_output_tokens=llm_conf_dict.get('llm_max_output_tokens', 10000),
        llm_use_async=llm_conf_dict.get('llm_use_async', False),
        llm_stream=llm_conf_dict.get('llm_stream', False)
    )
    self.llm_client = LLMClient(self.config)

    system_prompt = llm_conf_dict.get('llm_system_prompt', None)
    if system_prompt is not None:  # TODO: only add system prompt for llms that support it.
            self.state["conversation"] += [{ "role": "system", "content": system_prompt }]

    print("LLMBrain setup complete.")
think(stimuli, tools=None, stream=False, use_async=False) async

Asynchronously process input stimuli to update the cognitive state.

Parameters:

    stimuli (Any, required): The input stimuli to process.
    tools (List[Dict[str, Any]]): Optional tool schemas for function calls. Default: None.
    stream (bool): Whether to use streaming mode. Default: False.
    use_async (bool): Whether to call the LLM client asynchronously. Default: False.

Yields:

    str: The response, as incremental chunks in streaming mode or as one full string otherwise.

Source code in src/aeiva/cognition/brain/llm_brain.py
async def think(
        self,
        stimuli: Any,
        tools: List[Dict[str, Any]] = None,
        stream: bool = False,
        use_async: bool = False
        ) -> AsyncGenerator[str, None]:
    """
    Asynchronously process input stimuli to update the cognitive state.

    Args:
        stimuli (Any): The input stimuli to process.
        tools (List[Dict[str, Any]]): Optional tool schemas for function calls. Default is None.
        stream (bool): Whether to use streaming mode. Default is False.
        use_async (bool): Whether to call the LLM client asynchronously. Default is False.

    Yields:
        str: The response, as incremental chunks in streaming mode or as one full string otherwise.
    """
    try:
        # Stimuli must be a list of messages (conversation context).
        if not isinstance(stimuli, list):
            raise ValueError("Stimuli must be a list of messages.")

        self.state["conversation"] += stimuli  # NOTE: keep the history so the LLM remembers context

        if not use_async:  # NOTE: stream mode only works when use_async is True
            response = self.llm_client(self.state["conversation"], tools=tools, stream=stream)  # NOTE: the client updates the conversation
            self.state["cognitive_state"] = response
            yield response
        elif stream:
            # Stream mode: yield each delta and collect the full response.
            response = ""
            async for delta in self.llm_client(self.state["conversation"], tools=tools, stream=stream):
                response += delta
                yield delta
            self.state["cognitive_state"] = response
        else:
            response = await self.llm_client(self.state["conversation"], tools=tools, stream=stream)
            self.state["cognitive_state"] = response
            yield response

    except Exception as e:
        self.handle_error(e)
        raise
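
A usage sketch for LLMBrain. The config keys mirror the ones read by setup() above; the API key value is a placeholder. Note that think() is an async generator even in non-streaming mode, so it must be consumed with async for.

import asyncio

from aeiva.cognition.brain.llm_brain import LLMBrain  # module path as shown in the listings above

config = {
    "llm_gateway_config": {
        "llm_api_key": "sk-...",  # placeholder; supply a real key
        "llm_model_name": "gpt-4o",
        "llm_system_prompt": "You are a helpful assistant.",
    }
}

async def main():
    brain = LLMBrain(config)
    brain.setup()
    stimuli = [{"role": "user", "content": "Say hello in one word."}]
    async for chunk in brain.think(stimuli, stream=True, use_async=True):
        print(chunk, end="", flush=True)

asyncio.run(main())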

cognition_system

CognitionSystem

Processes Stimuli into Observations, uses the Brain to generate Thoughts, and orchestrates output into Plans.

Source code in src/aeiva/cognition/cognition_system.py
class CognitionSystem:
    """
    Processes Stimuli into Observations, uses the Brain to generate Thoughts, and orchestrates output into Plans.
    """
    def __init__(self, config: Dict):
        self.config_dict = config
        self.config = None
        self.input_interpreter = None
        self.brain = None
        self.output_orchestrator = None
        self.memory = None
        self.emotion = None
        self.world_model = None
        self.state = self.init_state()

    def init_state(self) -> Dict[str, Any]:
        return {
            "cognitive_state": None,
            "last_input": None,
            "last_output": None
        }

    def setup(self) -> None:
        """
        Set up the cognition system's components.
        """
        self.brain = LLMBrain(config=self.config_dict)
        self.memory = MemoryPalace(config=self.config_dict)
        self.emotion = SimpleEmotion()  # TODO: replace
        self.world_model = SimpleWorldModel()  # TODO: replace
        self.input_interpreter = SimpleInputInterpreter()  # TODO: replace
        self.output_orchestrator = SimpleOutputOrchestrator()  # TODO: replace

        self.brain.setup()
        self.memory.setup()
        self.world_model.setup()
        self.emotion.setup()
        self.input_interpreter.setup()
        self.output_orchestrator.setup()

    def handle_error(self, error: Exception) -> None:
        print(f"CognitionSystem encountered an error: {error}")

    async def think(
            self,
            stimuli: Stimuli,
            tools: List[Dict[str, Any]] = None,
            stream: bool=False,
            use_async: bool=False
            ) -> AsyncGenerator[str, None]:
        """
        Processes stimuli and produces a thought or plan.

        Args:
            stimuli (Stimuli): The input stimuli.
            tools (List[Dict[str, Any]]): Optional tools for function calls.
            stream (bool): Whether to use streaming mode.
            use_async (bool): Whether to call the LLM client asynchronously.

        Yields:
            str: Chunks of the assistant's response.
        """
        self.state["last_input"] = stimuli

        # Step 1: Use InputInterpreter to process stimuli into observation
        if self.input_interpreter.gate(stimuli):
            observation = await self.input_interpreter.interpret(stimuli)
        else:
            # Directly pass stimuli as observation (assuming it's acceptable)
            observation = Observation(data=stimuli.to_dict())

        # Step 2: Brain processes the observation into a thought or plan
        brain_input = [{"role": "user", "content": observation.data}]
        # Initiate brain processing
        response_gen = self.brain.think(brain_input, tools=tools, stream=stream, use_async=use_async)

        async for chunk in response_gen:
            if isinstance(chunk, str):
                # Streaming chunk or full response in non-streaming mode
                yield chunk
            elif isinstance(chunk, Thought):
                thought = chunk
                self.state["cognitive_state"] = thought

                # Step 3: Use OutputOrchestrator if applicable
                if self.output_orchestrator.gate(thought):
                    plan = await self.output_orchestrator.orchestrate(thought)
                    self.state["last_output"] = plan
                    yield plan.content if hasattr(plan, 'content') else str(plan)
                else:
                    self.state["last_output"] = thought
                    yield thought.content
            elif isinstance(chunk, Plan):
                plan = chunk
                self.state["last_output"] = plan
                yield plan.content if hasattr(plan, 'content') else str(plan)
            else:
                # Handle unexpected chunk types
                yield str(chunk)
setup()

Set up the cognition system's components.

Source code in src/aeiva/cognition/cognition_system.py
def setup(self) -> None:
    """
    Set up the cognition system's components.
    """
    self.brain = LLMBrain(config=self.config_dict)
    self.memory = MemoryPalace(config=self.config_dict)
    self.emotion = SimpleEmotion()  # TODO: replace
    self.world_model = SimpleWorldModel()  # TODO: replace
    self.input_interpreter = SimpleInputInterpreter()  # TODO: replace
    self.output_orchestrator = SimpleOutputOrchestrator()  # TODO: replace

    self.brain.setup()
    self.memory.setup()
    self.world_model.setup()
    self.emotion.setup()
    self.input_interpreter.setup()
    self.output_orchestrator.setup()
think(stimuli, tools=None, stream=False, use_async=False) async

Processes stimuli and produces a thought or plan.

Parameters:

    stimuli (Stimuli, required): The input stimuli.
    tools (List[Dict[str, Any]]): Optional tools for function calls. Default: None.
    stream (bool): Whether to use streaming mode. Default: False.
    use_async (bool): Whether to call the LLM client asynchronously. Default: False.

Yields:

    str: Chunks of the assistant's response.

Source code in src/aeiva/cognition/cognition_system.py
async def think(
        self,
        stimuli: Stimuli,
        tools: List[Dict[str, Any]] = None,
        stream: bool=False,
        use_async: bool=False
        ) -> AsyncGenerator[str, None]:
    """
    Processes stimuli and produces a thought or plan.

    Args:
        stimuli (Stimuli): The input stimuli.
        tools (List[Dict[str, Any]]): Optional tools for function calls.
        stream (bool): Whether to use streaming mode.
        use_async (bool): Whether to call the LLM client asynchronously.

    Yields:
        str: Chunks of the assistant's response.
    """
    self.state["last_input"] = stimuli

    # Step 1: Use InputInterpreter to process stimuli into observation
    if self.input_interpreter.gate(stimuli):
        observation = await self.input_interpreter.interpret(stimuli)
    else:
        # Directly pass stimuli as observation (assuming it's acceptable)
        observation = Observation(data=stimuli.to_dict())

    # Step 2: Brain processes the observation into a thought or plan
    brain_input = [{"role": "user", "content": observation.data}]
    # Initiate brain processing
    response_gen = self.brain.think(brain_input, tools=tools, stream=stream, use_async=use_async)

    async for chunk in response_gen:
        if isinstance(chunk, str):
            # Streaming chunk or full response in non-streaming mode
            yield chunk
        elif isinstance(chunk, Thought):
            thought = chunk
            self.state["cognitive_state"] = thought

            # Step 3: Use OutputOrchestrator if applicable
            if self.output_orchestrator.gate(thought):
                plan = await self.output_orchestrator.orchestrate(thought)
                self.state["last_output"] = plan
                yield plan.content if hasattr(plan, 'content') else str(plan)
            else:
                self.state["last_output"] = thought
                yield thought.content
        elif isinstance(chunk, Plan):
            plan = chunk
            self.state["last_output"] = plan
            yield plan.content if hasattr(plan, 'content') else str(plan)
        else:
            # Handle unexpected chunk types
            yield str(chunk)
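
A consumption sketch for CognitionSystem, following the same async-generator pattern as LLMBrain.think. The import path for Stimuli is an assumption (it is not shown in this reference), and the config dict is the same one that LLMBrain and MemoryPalace read.

import asyncio

from aeiva.cognition.cognition_system import CognitionSystem  # module path as shown in the listings above
# from aeiva.perception.stimuli import Stimuli  # assumed import path, not shown in this reference

async def main(stimuli):
    system = CognitionSystem(config={"llm_gateway_config": {"llm_api_key": "sk-..."}})
    system.setup()
    # think() yields strings: streamed chunks, or one full response.
    async for chunk in system.think(stimuli, stream=True, use_async=True):
        print(chunk, end="", flush=True)

# asyncio.run(main(Stimuli(...)))  # build the stimuli in your perception layer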

emotion

emotion

ConfigurationError

Bases: Exception

Exception raised for errors in the configuration.

Source code in src/aeiva/cognition/emotion/emotion.py
class ConfigurationError(Exception):
    """Exception raised for errors in the configuration."""
    pass
Emotion

Bases: ABC, Generic[T]

Abstract base class representing the Emotion system of an agent with generic state type.

The Emotion system manages the agent's emotional states, allowing it to respond to various stimuli in an emotionally coherent manner.

Attributes:

    config (Dict[str, Any]): Configuration settings for the Emotion system.
    state (T): The internal emotional state of the agent, defined by subclasses.

Source code in src/aeiva/cognition/emotion/emotion.py
class Emotion(ABC, Generic[T]):
    """
    Abstract base class representing the Emotion system of an agent with generic state type.

    The Emotion system manages the agent's emotional states, allowing it to respond
    to various stimuli in an emotionally coherent manner.

    Attributes:
        config (Dict[str, Any]): Configuration settings for the Emotion system.
        state (T): The internal emotional state of the agent, defined by subclasses.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the Emotion system with the provided configuration.

        Args:
            config (Dict[str, Any]): Configuration settings for the Emotion system.
        """
        self.config = config
        self.state = self.init_state()

    @abstractmethod
    def init_state(self) -> T:
        """
        Initialize the internal emotional state of the Emotion system.

        This method should set up the initial emotional state required for the
        Emotion system's operations.

        Returns:
            T: The initial emotional state of the agent.
        """
        pass

    @abstractmethod
    async def setup(self) -> None:
        """
        Asynchronously set up the Emotion system's components.

        This method should initialize any necessary components or resources
        based on the provided configuration.

        Raises:
            ConfigurationError: If the configuration is invalid or incomplete.
        """
        pass

    @abstractmethod
    async def update(self, input_data: Dict[str, Any]) -> None:
        """
        Asynchronously update the emotional state based on input data.

        Args:
            input_data (Dict[str, Any]): The data or stimuli that influence the emotional state.

        Raises:
            UpdateError: If updating the emotional state fails.
        """
        pass

    @abstractmethod
    def regulate(self, strategy: str) -> None:
        """
        Regulate the emotional state using a specified strategy.

        Args:
            strategy (str): The regulation strategy to apply (e.g., 'suppression', 'amplification').

        Raises:
            RegulationError: If the regulation strategy is invalid or fails.
        """
        pass

    @abstractmethod
    def express(self) -> str:
        """
        Generate a representation of the current emotional state.

        Returns:
            str: A string describing the current emotion (e.g., "I feel happy!").
        """
        pass

    @abstractmethod
    def serialize(self) -> str:
        """
        Serialize the current emotional state into a string format.

        Returns:
            str: Serialized emotional state.
        """
        pass

    @abstractmethod
    def deserialize(self, data: str) -> None:
        """
        Deserialize the emotional state from a string format.

        Args:
            data (str): Serialized emotional state.
        """
        pass

    def get_current_state(self) -> T:
        """
        Retrieve the current emotional state of the agent.

        Returns:
            T: The current emotional state.
        """
        return self.state

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during emotional processing.

        This method can be overridden to implement custom error handling logic.

        Args:
            error (Exception): The exception that was raised.
        """
        pass
__init__(config)

Initialize the Emotion system with the provided configuration.

Parameters:

    config (Dict[str, Any], required): Configuration settings for the Emotion system.
Source code in src/aeiva/cognition/emotion/emotion.py
def __init__(self, config: Dict[str, Any]):
    """
    Initialize the Emotion system with the provided configuration.

    Args:
        config (Dict[str, Any]): Configuration settings for the Emotion system.
    """
    self.config = config
    self.state = self.init_state()
deserialize(data) abstractmethod

Deserialize the emotional state from a string format.

Parameters:

    data (str, required): Serialized emotional state.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod
def deserialize(self, data: str) -> None:
    """
    Deserialize the emotional state from a string format.

    Args:
        data (str): Serialized emotional state.
    """
    pass
express() abstractmethod

Generate a representation of the current emotional state.

Returns:

    str: A string describing the current emotion (e.g., "I feel happy!").

Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod
def express(self) -> str:
    """
    Generate a representation of the current emotional state.

    Returns:
        str: A string describing the current emotion (e.g., "I feel happy!").
    """
    pass
get_current_state()

Retrieve the current emotional state of the agent.

Returns:

    T: The current emotional state.

Source code in src/aeiva/cognition/emotion/emotion.py
def get_current_state(self) -> T:
    """
    Retrieve the current emotional state of the agent.

    Returns:
        T: The current emotional state.
    """
    return self.state
handle_error(error)

Handle errors that occur during emotional processing.

This method can be overridden to implement custom error handling logic.

Parameters:

    error (Exception, required): The exception that was raised.
Source code in src/aeiva/cognition/emotion/emotion.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during emotional processing.

    This method can be overridden to implement custom error handling logic.

    Args:
        error (Exception): The exception that was raised.
    """
    pass
init_state() abstractmethod

Initialize the internal emotional state of the Emotion system.

This method should set up the initial emotional state required for the Emotion system's operations.

Returns:

    T: The initial emotional state of the agent.

Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod
def init_state(self) -> T:
    """
    Initialize the internal emotional state of the Emotion system.

    This method should set up the initial emotional state required for the
    Emotion system's operations.

    Returns:
        T: The initial emotional state of the agent.
    """
    pass
regulate(strategy) abstractmethod

Regulate the emotional state using a specified strategy.

Parameters:

    strategy (str, required): The regulation strategy to apply (e.g., 'suppression', 'amplification').

Raises:

    RegulationError: If the regulation strategy is invalid or fails.

Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod
def regulate(self, strategy: str) -> None:
    """
    Regulate the emotional state using a specified strategy.

    Args:
        strategy (str): The regulation strategy to apply (e.g., 'suppression', 'amplification').

    Raises:
        RegulationError: If the regulation strategy is invalid or fails.
    """
    pass
serialize() abstractmethod

Serialize the current emotional state into a string format.

Returns:

    str: Serialized emotional state.

Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod
def serialize(self) -> str:
    """
    Serialize the current emotional state into a string format.

    Returns:
        str: Serialized emotional state.
    """
    pass
setup() abstractmethod async

Asynchronously set up the Emotion system's components.

This method should initialize any necessary components or resources based on the provided configuration.

Raises:

    ConfigurationError: If the configuration is invalid or incomplete.

Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod
async def setup(self) -> None:
    """
    Asynchronously set up the Emotion system's components.

    This method should initialize any necessary components or resources
    based on the provided configuration.

    Raises:
        ConfigurationError: If the configuration is invalid or incomplete.
    """
    pass
update(input_data) abstractmethod async

Asynchronously update the emotional state based on input data.

Parameters:

    input_data (Dict[str, Any], required): The data or stimuli that influence the emotional state.

Raises:

    UpdateError: If updating the emotional state fails.

Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod
async def update(self, input_data: Dict[str, Any]) -> None:
    """
    Asynchronously update the emotional state based on input data.

    Args:
        input_data (Dict[str, Any]): The data or stimuli that influence the emotional state.

    Raises:
        UpdateError: If updating the emotional state fails.
    """
    pass
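
A minimal concrete Emotion, sketched against the interface above. The single-valence state, the regulation rule, and the JSON serialization are illustrative choices, not part of the framework.

import json
from typing import Any, Dict

from aeiva.cognition.emotion.emotion import Emotion  # module path as shown in the listings above


class MoodEmotion(Emotion[Dict[str, float]]):
    """Toy Emotion tracking a single valence value in [-1.0, 1.0]."""

    def init_state(self) -> Dict[str, float]:
        return {"valence": 0.0}

    async def setup(self) -> None:
        pass  # nothing to configure in this sketch

    async def update(self, input_data: Dict[str, Any]) -> None:
        delta = float(input_data.get("valence_delta", 0.0))  # hypothetical input key
        self.state["valence"] = max(-1.0, min(1.0, self.state["valence"] + delta))

    def regulate(self, strategy: str) -> None:
        if strategy == "suppression":
            self.state["valence"] *= 0.5  # damp the current mood

    def express(self) -> str:
        return "I feel good!" if self.state["valence"] > 0 else "I feel flat."

    def serialize(self) -> str:
        return json.dumps(self.state)

    def deserialize(self, data: str) -> None:
        self.state = json.loads(data)
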
RegulationError

Bases: Exception

Exception raised for errors during emotion regulation.

Source code in src/aeiva/cognition/emotion/emotion.py
class RegulationError(Exception):
    """Exception raised for errors during emotion regulation."""
    pass
UpdateError

Bases: Exception

Exception raised for errors during emotion state updates.

Source code in src/aeiva/cognition/emotion/emotion.py
class UpdateError(Exception):
    """Exception raised for errors during emotion state updates."""
    pass

emotion_categorical

CategoricalEmotionState

Represents the emotional state in a Categorical Model.

Source code in src/aeiva/cognition/emotion/emotion_categorical.py
class CategoricalEmotionState:
    """
    Represents the emotional state in a Categorical Model.
    """
    def __init__(self, emotion_label: str = "neutral"):
        self.emotion_label = emotion_label

    def to_dict(self) -> Dict[str, Any]:
        return {
            'emotion_label': self.emotion_label
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]):
        return CategoricalEmotionState(
            emotion_label=data.get('emotion_label', 'neutral')
        )

emotion_category

CategoryEmotionState dataclass

Represents the emotional state in a Category-Based Model with extensive categories.

Attributes:

    emotion_label (str): The current emotion category.
    intensity (float): The intensity of the current emotion (range: 0.0 to 1.0).

Source code in src/aeiva/cognition/emotion/emotion_category.py
@dataclass
class CategoryEmotionState:
    """
    Represents the emotional state in a Category-Based Model with extensive categories.

    Attributes:
        emotion_label (str): The current emotion category.
        intensity (float): The intensity of the current emotion (range: 0.0 to 1.0).
    """
    emotion_label: str = "neutral"
    intensity: float = 0.0  # Optional: Represents the strength of the emotion

    def to_dict(self) -> Dict[str, Any]:
        return {
            'emotion_label': self.emotion_label,
            'intensity': self.intensity
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]):
        return CategoryEmotionState(
            emotion_label=data.get('emotion_label', 'neutral'),
            intensity=data.get('intensity', 0.0)
        )

emotion_circumplex

CircumplexEmotionState

Represents the emotional state in the Circumplex Model.

Source code in src/aeiva/cognition/emotion/emotion_circumplex.py
class CircumplexEmotionState:
    """
    Represents the emotional state in the Circumplex Model.
    """
    def __init__(self, valence: float = 0.0, arousal: float = 0.0):
        self.valence = valence  # Range: [-1.0, 1.0]
        self.arousal = arousal  # Range: [-1.0, 1.0]

    def to_dict(self) -> Dict[str, Any]:
        return {
            'valence': self.valence,
            'arousal': self.arousal
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]):
        return CircumplexEmotionState(
            valence=data.get('valence', 0.0),
            arousal=data.get('arousal', 0.0)
        )

emotion_componential

ComponentialEmotionState dataclass

Represents the emotional state based on the Componential Model.

Attributes:

    emotion_label (str): Current emotion category.
    intensity (float): Intensity of the emotion (0.0 to 1.0).

Source code in src/aeiva/cognition/emotion/emotion_componential.py
@dataclass
class ComponentialEmotionState:
    """
    Represents the emotional state based on the Componential Model.

    Attributes:
        emotion_label (str): Current emotion category.
        intensity (float): Intensity of the emotion (0.0 to 1.0).
    """
    emotion_label: str = "neutral"
    intensity: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        return {
            'emotion_label': self.emotion_label,
            'intensity': self.intensity
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]):
        return ComponentialEmotionState(
            emotion_label=data.get('emotion_label', 'neutral'),
            intensity=data.get('intensity', 0.0)
        )

emotion_hybrid

HybridEmotionState

Represents the emotional state in the Hybrid Categorical-Dimensional Model.

Source code in src/aeiva/cognition/emotion/emotion_hybrid.py
class HybridEmotionState:
    """
    Represents the emotional state in the Hybrid Categorical-Dimensional Model.
    """
    def __init__(self, emotion_label: str = "neutral", valence: float = 0.0, arousal: float = 0.0):
        self.emotion_label = emotion_label  # Categorical label
        self.valence = valence              # Dimensional valence
        self.arousal = arousal              # Dimensional arousal

    def to_dict(self) -> Dict[str, Any]:
        return {
            'emotion_label': self.emotion_label,
            'valence': self.valence,
            'arousal': self.arousal
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]):
        return HybridEmotionState(
            emotion_label=data.get('emotion_label', 'neutral'),
            valence=data.get('valence', 0.0),
            arousal=data.get('arousal', 0.0)
        )

emotion_occ

OCCEmotionState

Represents the emotional state in the OCC Appraisal-Based Model.

Source code in src/aeiva/cognition/emotion/emotion_occ.py
class OCCEmotionState:
    """
    Represents the emotional state in the OCC Appraisal-Based Model.
    """
    def __init__(self, emotion_categories: Dict[str, float] = None):
        """
        Initialize the OCC emotion state with emotion categories and their intensities.
        """
        # Initialize with zero intensities if not provided
        self.emotion_categories = emotion_categories if emotion_categories else {
            'joy': 0.0,
            'sadness': 0.0,
            'anger': 0.0,
            'fear': 0.0,
            'surprise': 0.0,
            'disgust': 0.0
        }

    def to_dict(self) -> Dict[str, Any]:
        return {
            'emotion_categories': self.emotion_categories
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]):
        return OCCEmotionState(
            emotion_categories=data.get('emotion_categories', {})
        )
__init__(emotion_categories=None)

Initialize the OCC emotion state with emotion categories and their intensities.

Source code in src/aeiva/cognition/emotion/emotion_occ.py
def __init__(self, emotion_categories: Dict[str, float] = None):
    """
    Initialize the OCC emotion state with emotion categories and their intensities.
    """
    # Initialize with zero intensities if not provided
    self.emotion_categories = emotion_categories if emotion_categories else {
        'joy': 0.0,
        'sadness': 0.0,
        'anger': 0.0,
        'fear': 0.0,
        'surprise': 0.0,
        'disgust': 0.0
    }

emotion_pad

PADEmotionState

Represents the emotional state in the PAD Model.

Source code in src/aeiva/cognition/emotion/emotion_pad.py
class PADEmotionState:
    """
    Represents the emotional state in the PAD Model.
    """
    def __init__(self, pleasure: float = 0.0, arousal: float = 0.0, dominance: float = 0.0):
        self.pleasure = pleasure      # Range: [-1.0, 1.0]
        self.arousal = arousal        # Range: [-1.0, 1.0]
        self.dominance = dominance    # Range: [-1.0, 1.0]

    def to_dict(self) -> Dict[str, Any]:
        return {
            'pleasure': self.pleasure,
            'arousal': self.arousal,
            'dominance': self.dominance
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]):
        return PADEmotionState(
            pleasure=data.get('pleasure', 0.0),
            arousal=data.get('arousal', 0.0),
            dominance=data.get('dominance', 0.0)
        )

emotion_plutchik

PlutchikEmotionState dataclass

Represents the emotional state in Plutchik's Wheel of Emotions.

Attributes:

    joy (float): Intensity of Joy.
    trust (float): Intensity of Trust.
    fear (float): Intensity of Fear.
    surprise (float): Intensity of Surprise.
    sadness (float): Intensity of Sadness.
    disgust (float): Intensity of Disgust.
    anger (float): Intensity of Anger.
    anticipation (float): Intensity of Anticipation.

Source code in src/aeiva/cognition/emotion/emotion_plutchik.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
@dataclass
class PlutchikEmotionState:
    """
    Represents the emotional state in Plutchik's Wheel of Emotions.

    Attributes:
        joy (float): Intensity of Joy.
        trust (float): Intensity of Trust.
        fear (float): Intensity of Fear.
        surprise (float): Intensity of Surprise.
        sadness (float): Intensity of Sadness.
        disgust (float): Intensity of Disgust.
        anger (float): Intensity of Anger.
        anticipation (float): Intensity of Anticipation.
    """
    joy: float = 0.0
    trust: float = 0.0
    fear: float = 0.0
    surprise: float = 0.0
    sadness: float = 0.0
    disgust: float = 0.0
    anger: float = 0.0
    anticipation: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        return {
            'joy': self.joy,
            'trust': self.trust,
            'fear': self.fear,
            'surprise': self.surprise,
            'sadness': self.sadness,
            'disgust': self.disgust,
            'anger': self.anger,
            'anticipation': self.anticipation
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]):
        return PlutchikEmotionState(
            joy=data.get('joy', 0.0),
            trust=data.get('trust', 0.0),
            fear=data.get('fear', 0.0),
            surprise=data.get('surprise', 0.0),
            sadness=data.get('sadness', 0.0),
            disgust=data.get('disgust', 0.0),
            anger=data.get('anger', 0.0),
            anticipation=data.get('anticipation', 0.0)
        )
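
All of these state classes share the same to_dict/from_dict round-trip; a short check using the Plutchik state (the field values here are arbitrary):

state = PlutchikEmotionState(joy=0.8, anticipation=0.3)
payload = state.to_dict()       # plain dict, safe to persist as JSON
restored = PlutchikEmotionState.from_dict(payload)
assert restored == state        # dataclass equality compares all eight fields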

exceptions

ConfigurationError

Bases: Exception

Exception raised for errors in the configuration.

Source code in src/aeiva/cognition/emotion/exceptions.py
class ConfigurationError(Exception):
    """Exception raised for errors in the configuration."""
    pass
RegulationError

Bases: Exception

Exception raised for errors during emotion regulation.

Source code in src/aeiva/cognition/emotion/exceptions.py
class RegulationError(Exception):
    """Exception raised for errors during emotion regulation."""
    pass
UpdateError

Bases: Exception

Exception raised for errors during emotion state updates.

Source code in src/aeiva/cognition/emotion/exceptions.py
class UpdateError(Exception):
    """Exception raised for errors during emotion state updates."""
    pass

memory

memory

Memory

Bases: ABC

Abstract base class for memory operations in the intelligent agent.

This class defines methods corresponding to different layers of memory processing, such as creating, filtering, grouping, deriving, structuring, skillizing, embedding, and parameterizing memory units.

Source code in src/aeiva/cognition/memory/memory.py
class Memory(ABC):
    """
    Abstract base class for memory operations in the intelligent agent.

    This class defines methods corresponding to different layers of memory processing,
    such as creating, filtering, grouping, deriving, structuring, skillizing, embedding,
    and parameterizing memory units.
    """

    def __init__(self, config: Any):
        """
        Initialize the Memory system with the provided configuration.

        Args:
            config (Any): Configuration settings for the Memory system.
        """
        self.config = config

    @abstractmethod
    def setup(self) -> None:
        """
        Set up the Memory system's components.

        This method should initialize any necessary components or resources based on the provided configuration.

        Raises:
            ConfigurationError: If the configuration is invalid or incomplete.
        """
        pass

    @abstractmethod
    def create(self, content: Any, **kwargs) -> MemoryUnit:
        """
        Creates a new memory unit with the given content and metadata.

        Args:
            content (Any): The core content of the memory unit.
            **kwargs: Additional metadata for the memory unit.

        Returns:
            MemoryUnit: The created memory unit.
        """
        pass

    @abstractmethod
    def get(self, unit_id: str) -> MemoryUnit:
        """
        Retrieves a memory unit by its unique identifier.

        Args:
            unit_id (str): The unique identifier of the memory unit.

        Returns:
            MemoryUnit: The retrieved memory unit.
        """
        pass

    @abstractmethod
    def update(self, unit_id: str, updates: Dict[str, Any]) -> None:
        """
        Updates a memory unit with the given updates.

        Args:
            unit_id (str): The unique identifier of the memory unit.
            updates (Dict[str, Any]): A dictionary of fields to update.
        """
        pass

    @abstractmethod
    def delete(self, unit_id: str) -> None:
        """
        Deletes a memory unit by its unique identifier.

        Args:
            unit_id (str): The unique identifier of the memory unit.
        """
        pass

    @abstractmethod
    def get_all(self) -> List[MemoryUnit]:
        """
        Retrieves all memory units.

        Returns:
            List[MemoryUnit]: A list of all memory units.
        """
        pass

    @abstractmethod
    def delete_all(self) -> None:
        """
        Deletes all memory units.
        """
        pass

    @abstractmethod
    def load(self) -> None:
        """
        Loads the memory from file. The path is specified in config.
        """
        pass

    @abstractmethod
    def save(self) -> None:
        """
        Save the memory to database or file. The path is specified in config.
        """
        pass

    @abstractmethod
    def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:
        """
        Filters memory units based on the given criteria.

        Args:
            criteria (Dict[str, Any]): A dictionary of filter conditions.

        Returns:
            List[MemoryUnit]: A list of memory units matching the criteria.
        """
        pass

    @abstractmethod
    def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Groups memory units into a meaningful group.

        Args:
            unit_ids (List[str]): A list of memory unit IDs to group.
            organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').
            metadata (Optional[Dict[str, Any]]): Additional metadata for the group.

        Returns:
            str: A unique identifier for the created group.
        """
        pass

    # @abstractmethod
    # def derive(self, unit_ids: List[str], derivation_type: str, **kwargs) -> MemoryUnit:
    #     """
    #     Derives a new memory unit from existing ones.

    #     Args:
    #         unit_ids (List[str]): A list of memory unit IDs to derive from.
    #         derivation_type (str): The type of derivation (e.g., 'summary', 'transformation').
    #         **kwargs: Additional parameters for the derivation process.

    #     Returns:
    #         MemoryUnit: The derived memory unit.
    #     """
    #     pass

    @abstractmethod
    def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:
        """
        Structures memory units into a knowledge graph or other structures.

        Args:
            unit_ids (List[str]): A list of memory unit IDs to structurize.
            structure_type (str): The type of structure (e.g., 'knowledge_graph').
            **kwargs: Additional parameters for the structuring process.
        """
        pass

    @abstractmethod
    def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:
        """
        Converts memory units into a reusable skill.

        Args:
            unit_ids (List[str]): A list of memory unit IDs to skillize.
            skill_name (str): The name of the skill to create.
            **kwargs: Additional parameters for skill creation.

        Returns:
            str: The unique identifier of the created skill.
        """
        pass

    @abstractmethod
    def embed(self, unit_id: str) -> None:
        """
        Generates an embedding for a memory unit.

        Args:
            unit_id (str): The unique identifier of the memory unit.
        """
        pass

    @abstractmethod
    def parameterize(self, **kwargs) -> None:
        """
        Trains a parametric model using the memory data.

        Args:
            **kwargs: Additional parameters for the training process.
        """
        pass

    @abstractmethod
    def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:
        """
        Retrieve data from memory based on a query.

        Args:
            query (Any): The query or criteria to retrieve specific memory data.
            retrieve_type (str): The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar').
            **kwargs: Additional parameters for the retrieval process.

        Returns:
            List[MemoryUnit]: The retrieved memory units.

        Raises:
            RetrievalError: If the retrieval process fails.
        """
        pass

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during memory operations.

        This method can be overridden to implement custom error handling logic.

        Args:
            error (Exception): The exception that was raised.
        """
        # Default error handling: log the error
        print(f"Memory system encountered an error: {error}")
__init__(config)

Initialize the Memory system with the provided configuration.

Parameters:

    config (Any, required): Configuration settings for the Memory system.
Source code in src/aeiva/cognition/memory/memory.py
def __init__(self, config: Any):
    """
    Initialize the Memory system with the provided configuration.

    Args:
        config (Any): Configuration settings for the Memory system.
    """
    self.config = config
create(content, **kwargs) abstractmethod

Creates a new memory unit with the given content and metadata.

Parameters:

    content (Any, required): The core content of the memory unit.
    **kwargs: Additional metadata for the memory unit. Default: {}.

Returns:

    MemoryUnit: The created memory unit.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def create(self, content: Any, **kwargs) -> MemoryUnit:
    """
    Creates a new memory unit with the given content and metadata.

    Args:
        content (Any): The core content of the memory unit.
        **kwargs: Additional metadata for the memory unit.

    Returns:
        MemoryUnit: The created memory unit.
    """
    pass
delete(unit_id) abstractmethod

Deletes a memory unit by its unique identifier.

Parameters:

    unit_id (str, required): The unique identifier of the memory unit.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def delete(self, unit_id: str) -> None:
    """
    Deletes a memory unit by its unique identifier.

    Args:
        unit_id (str): The unique identifier of the memory unit.
    """
    pass
delete_all() abstractmethod

Deletes all memory units.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def delete_all(self) -> None:
    """
    Deletes all memory units.
    """
    pass
embed(unit_id) abstractmethod

Generates an embedding for a memory unit.

Parameters:

    unit_id (str, required): The unique identifier of the memory unit.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def embed(self, unit_id: str) -> None:
    """
    Generates an embedding for a memory unit.

    Args:
        unit_id (str): The unique identifier of the memory unit.
    """
    pass
filter(criteria) abstractmethod

Filters memory units based on the given criteria.

Parameters:

    criteria (Dict[str, Any]): A dictionary of filter conditions. [required]

Returns:

    List[MemoryUnit]: A list of memory units matching the criteria.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:
    """
    Filters memory units based on the given criteria.

    Args:
        criteria (Dict[str, Any]): A dictionary of filter conditions.

    Returns:
        List[MemoryUnit]: A list of memory units matching the criteria.
    """
    pass
get(unit_id) abstractmethod

Retrieves a memory unit by its unique identifier.

Parameters:

    unit_id (str): The unique identifier of the memory unit. [required]

Returns:

    MemoryUnit: The retrieved memory unit.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def get(self, unit_id: str) -> MemoryUnit:
    """
    Retrieves a memory unit by its unique identifier.

    Args:
        unit_id (str): The unique identifier of the memory unit.

    Returns:
        MemoryUnit: The retrieved memory unit.
    """
    pass
get_all() abstractmethod

Retrieves all memory units.

Returns:

    List[MemoryUnit]: A list of all memory units.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def get_all(self) -> List[MemoryUnit]:
    """
    Retrieves all memory units.

    Returns:
        List[MemoryUnit]: A list of all memory units.
    """
    pass
handle_error(error)

Handle errors that occur during memory operations.

This method can be overridden to implement custom error handling logic.

Parameters:

    error (Exception): The exception that was raised. [required]
Source code in src/aeiva/cognition/memory/memory.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during memory operations.

    This method can be overridden to implement custom error handling logic.

    Args:
        error (Exception): The exception that was raised.
    """
    # Default error handling: log the error
    print(f"Memory system encountered an error: {error}")
load() abstractmethod

Loads the memory from file. The path is specified in config.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def load(self) -> None:
    """
    Loads the memory from file. The path is specified in config.
    """
    pass
organize(unit_ids, organize_type, metadata=None) abstractmethod

Groups memory units into a meaningful group.

Parameters:

    unit_ids (List[str]): A list of memory unit IDs to group. [required]
    organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure'). [required]
    metadata (Optional[Dict[str, Any]]): Additional metadata for the group. [default: None]

Returns:

    str: A unique identifier for the created group.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:
    """
    Groups memory units into a meaningful group.

    Args:
        unit_ids (List[str]): A list of memory unit IDs to group.
        organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').
        metadata (Optional[Dict[str, Any]]): Additional metadata for the group.

    Returns:
        str: A unique identifier for the created group.
    """
    pass
parameterize(**kwargs) abstractmethod

Trains a parametric model using the memory data.

Parameters:

    **kwargs: Additional parameters for the training process.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def parameterize(self, **kwargs) -> None:
    """
    Trains a parametric model using the memory data.

    Args:
        **kwargs: Additional parameters for the training process.
    """
    pass
retrieve(query, retrieve_type, **kwargs) abstractmethod

Retrieve data from memory based on a query.

Parameters:

    query (Any): The query or criteria to retrieve specific memory data. [required]
    retrieve_type (str): The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar'). [required]
    **kwargs: Additional parameters for the retrieval process.

Returns:

    List[MemoryUnit]: The retrieved memory units.

Raises:

    RetrievalError: If the retrieval process fails.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:
    """
    Retrieve data from memory based on a query.

    Args:
        query (Any): The query or criteria to retrieve specific memory data.
        retrieve_type (str): The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar').
        **kwargs: Additional parameters for the retrieval process.

    Returns:
        List[MemoryUnit]: The retrieved memory units.

    Raises:
        RetrievalError: If the retrieval process fails.
    """
    pass
save() abstractmethod

Save the memory to database or file. The path is specified in config.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def save(self) -> None:
    """
    Save the memory to database or file. The path is specified in config.
    """
    pass
setup() abstractmethod

Set up the Memory system's components.

This method should initialize any necessary components or resources based on the provided configuration.

Raises:

    ConfigurationError: If the configuration is invalid or incomplete.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def setup(self) -> None:
    """
    Set up the Memory system's components.

    This method should initialize any necessary components or resources based on the provided configuration.

    Raises:
        ConfigurationError: If the configuration is invalid or incomplete.
    """
    pass
skillize(unit_ids, skill_name, **kwargs) abstractmethod

Converts memory units into a reusable skill.

Parameters:

    unit_ids (List[str]): A list of memory unit IDs to skillize. [required]
    skill_name (str): The name of the skill to create. [required]
    **kwargs: Additional parameters for skill creation.

Returns:

    str: The unique identifier of the created skill.

Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:
    """
    Converts memory units into a reusable skill.

    Args:
        unit_ids (List[str]): A list of memory unit IDs to skillize.
        skill_name (str): The name of the skill to create.
        **kwargs: Additional parameters for skill creation.

    Returns:
        str: The unique identifier of the created skill.
    """
    pass
structurize(unit_ids, structure_type, **kwargs) abstractmethod

Structures memory units into a knowledge graph or other structures.

Parameters:

    unit_ids (List[str]): A list of memory unit IDs to structurize. [required]
    structure_type (str): The type of structure (e.g., 'knowledge_graph'). [required]
    **kwargs: Additional parameters for the structuring process.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:
    """
    Structures memory units into a knowledge graph or other structures.

    Args:
        unit_ids (List[str]): A list of memory unit IDs to structurize.
        structure_type (str): The type of structure (e.g., 'knowledge_graph').
        **kwargs: Additional parameters for the structuring process.
    """
    pass
update(unit_id, updates) abstractmethod

Updates a memory unit with the given updates.

Parameters:

    unit_id (str): The unique identifier of the memory unit. [required]
    updates (Dict[str, Any]): A dictionary of fields to update. [required]
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod
def update(self, unit_id: str, updates: Dict[str, Any]) -> None:
    """
    Updates a memory unit with the given updates.

    Args:
        unit_id (str): The unique identifier of the memory unit.
        updates (Dict[str, Any]): A dictionary of fields to update.
    """
    pass
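
Since Memory is abstract, a subclass must implement every abstract method above. Below is a minimal, dict-backed sketch for illustration only: it is not part of the library, the import paths are inferred from the source paths shown above, and a real implementation would also persist, embed, and index its units.

from typing import Any, Dict, List, Optional

from aeiva.cognition.memory.memory import Memory
from aeiva.cognition.memory.memory_unit import MemoryUnit  # path assumed

class DictMemory(Memory):
    """Toy in-process Memory backed by a dict; for illustration only."""

    def __init__(self, config: Any):
        super().__init__(config)
        self._units: Dict[str, MemoryUnit] = {}

    def setup(self) -> None:
        pass  # nothing to initialize for an in-process store

    def create(self, content: Any, **kwargs) -> MemoryUnit:
        unit = MemoryUnit(content=content, **kwargs)
        self._units[unit.id] = unit
        return unit

    def get(self, unit_id: str) -> MemoryUnit:
        return self._units[unit_id]

    def update(self, unit_id: str, updates: Dict[str, Any]) -> None:
        for field_name, value in updates.items():
            setattr(self._units[unit_id], field_name, value)

    def delete(self, unit_id: str) -> None:
        self._units.pop(unit_id, None)

    def get_all(self) -> List[MemoryUnit]:
        return list(self._units.values())

    def delete_all(self) -> None:
        self._units.clear()

    def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:
        # Keep units whose attributes match every criterion.
        return [u for u in self._units.values()
                if all(getattr(u, k, None) == v for k, v in criteria.items())]

    # Remaining hooks stubbed out; a real subclass would implement them.
    def load(self) -> None: ...
    def save(self) -> None: ...
    def embed(self, unit_id: str) -> None: ...
    def organize(self, unit_ids: List[str], organize_type: str,
                 metadata: Optional[Dict[str, Any]] = None) -> str: ...
    def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None: ...
    def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str: ...
    def parameterize(self, **kwargs) -> None: ...
    def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]: ...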

memory_cleaner

MemoryCleaner

A class to clean memory units based on various filtering algorithms.

Supported filter types:
  • 'time': Removes memory units older than a specified threshold.
  • 'modality': Keeps only memory units matching specified modalities.
  • 'type': Keeps only memory units matching specified types.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
class MemoryCleaner:
    """
    A class to clean memory units based on various filtering algorithms.

    Supported filter types:
        - 'time': Removes memory units older than a specified threshold.
        - 'modality': Keeps only memory units matching specified modalities.
        - 'type': Keeps only memory units matching specified types.
    """

    def __init__(self):
        """
        Initializes the MemoryCleaner.

        Currently, no initialization parameters are required.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug("Initialized MemoryCleaner without default parameters.")

    def filter(
        self,
        memory_units: List[MemoryUnit],
        filter_type: str,
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Filters the provided memory units based on the specified filter type.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be filtered.
            filter_type (str): The type of filtering algorithm to use ('time', 'modality', 'type').
            **kwargs: Additional parameters required for specific filters.
                For 'time' filter:
                    - threshold_days (int): Number of days beyond which memory units are removed.
                For 'modality' filter:
                    - modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).
                For 'type' filter:
                    - types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).

        Returns:
            List[MemoryUnit]: The list of memory units after filtering.

        Raises:
            MemoryCleanerError: If an unknown filter_type is provided or if required parameters are missing.
        """
        self.logger.debug(f"Filtering memory units using filter_type='{filter_type}' with kwargs={kwargs}")
        try:
            if filter_type == 'time':
                threshold_days = kwargs.get('threshold_days')
                if threshold_days is None:
                    self.logger.error("Missing 'threshold_days' parameter for time-based filtering.")
                    raise MemoryCleanerError("Missing 'threshold_days' parameter for time-based filtering.")
                return self.filter_by_time(memory_units, threshold_days)
            elif filter_type == 'modality':
                modalities = kwargs.get('modalities')
                if not modalities:
                    self.logger.error("Missing 'modalities' parameter for modality-based filtering.")
                    raise MemoryCleanerError("Missing 'modalities' parameter for modality-based filtering.")
                return self.filter_by_modality(memory_units, modalities)
            elif filter_type == 'type':
                types = kwargs.get('types')
                if not types:
                    self.logger.error("Missing 'types' parameter for type-based filtering.")
                    raise MemoryCleanerError("Missing 'types' parameter for type-based filtering.")
                return self.filter_by_type(memory_units, types)
            else:
                self.logger.error(f"Unknown filter_type: {filter_type}")
                raise MemoryCleanerError(f"Unknown filter_type: {filter_type}")
        except MemoryCleanerError:
            # Re-raise custom errors without modification
            raise
        except Exception as e:
            self.logger.error(f"Failed to filter memory units: {e}")
            raise MemoryCleanerError(f"Failed to filter memory units: {e}")
    # TODO: more filter options

    def filter_by_time(self, memory_units: List[MemoryUnit], threshold_days: int) -> List[MemoryUnit]:
        """
        Removes memory units older than the specified threshold_days.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be filtered.
            threshold_days (int): Number of days beyond which memory units are removed.

        Returns:
            List[MemoryUnit]: The list of memory units after time-based filtering.

        Raises:
            MemoryCleanerError: If filtering fails.
        """
        self.logger.debug(f"Applying time-based filtering with threshold_days={threshold_days}")
        try:
            current_time = datetime.now(UTC)
            threshold = timedelta(days=threshold_days)
            filtered_memory = [
                mu for mu in memory_units
                if (current_time - mu.timestamp) <= threshold
            ]
            removed_count = len(memory_units) - len(filtered_memory)
            self.logger.info(
                f"Time-based filter: Removed {removed_count} memory units older than {threshold_days} days."
            )
            return filtered_memory
        except Exception as e:
            self.logger.error(f"Time-based filtering failed: {e}")
            raise MemoryCleanerError(f"Time-based filtering failed: {e}")

    def filter_by_modality(self, memory_units: List[MemoryUnit], modalities: List[str]) -> List[MemoryUnit]:
        """
        Keeps only memory units that match the specified modalities.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be filtered.
            modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).

        Returns:
            List[MemoryUnit]: The list of memory units after modality-based filtering.

        Raises:
            MemoryCleanerError: If filtering fails.
        """
        self.logger.debug(f"Applying modality-based filtering with modalities={modalities}")
        try:
            if not modalities:
                self.logger.warning("No modalities specified for modality-based filtering. Returning original memory units.")
                return memory_units

            filtered_memory = [
                mu for mu in memory_units
                if mu.modality in modalities
            ]
            removed_count = len(memory_units) - len(filtered_memory)
            self.logger.info(
                f"Modality-based filter: Removed {removed_count} memory units not in modalities {modalities}."
            )
            return filtered_memory
        except Exception as e:
            self.logger.error(f"Modality-based filtering failed: {e}")
            raise MemoryCleanerError(f"Modality-based filtering failed: {e}")

    def filter_by_type(self, memory_units: List[MemoryUnit], types: List[str]) -> List[MemoryUnit]:
        """
        Keeps only memory units that match the specified types.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be filtered.
            types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).

        Returns:
            List[MemoryUnit]: The list of memory units after type-based filtering.

        Raises:
            MemoryCleanerError: If filtering fails.
        """
        self.logger.debug(f"Applying type-based filtering with types={types}")
        try:
            if not types:
                self.logger.warning("No types specified for type-based filtering. Returning original memory units.")
                return memory_units

            filtered_memory = [
                mu for mu in memory_units
                if mu.type in types
            ]
            removed_count = len(memory_units) - len(filtered_memory)
            self.logger.info(
                f"Type-based filter: Removed {removed_count} memory units not in types {types}."
            )
            return filtered_memory
        except Exception as e:
            self.logger.error(f"Type-based filtering failed: {e}")
            raise MemoryCleanerError(f"Type-based filtering failed: {e}")
__init__()

Initializes the MemoryCleaner.

Currently, no initialization parameters are required.

Source code in src/aeiva/cognition/memory/memory_cleaner.py
def __init__(self):
    """
    Initializes the MemoryCleaner.

    Currently, no initialization parameters are required.
    """
    self.logger = logging.getLogger(self.__class__.__name__)
    self.logger.debug("Initialized MemoryCleaner without default parameters.")
filter(memory_units, filter_type, **kwargs)

Filters the provided memory units based on the specified filter type.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to be filtered. [required]
    filter_type (str): The type of filtering algorithm to use ('time', 'modality', 'type'). [required]
    **kwargs: Additional parameters required for specific filters.
        For 'time': threshold_days (int) - number of days beyond which memory units are removed.
        For 'modality': modalities (List[str]) - list of modalities to retain (e.g., ['text', 'image']).
        For 'type': types (List[str]) - list of types to retain (e.g., ['dialogue', 'summary']).

Returns:

    List[MemoryUnit]: The list of memory units after filtering.

Raises:

    MemoryCleanerError: If an unknown filter_type is provided or if required parameters are missing.

Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter(
    self,
    memory_units: List[MemoryUnit],
    filter_type: str,
    **kwargs
) -> List[MemoryUnit]:
    """
    Filters the provided memory units based on the specified filter type.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be filtered.
        filter_type (str): The type of filtering algorithm to use ('time', 'modality', 'type').
        **kwargs: Additional parameters required for specific filters.
            For 'time' filter:
                - threshold_days (int): Number of days beyond which memory units are removed.
            For 'modality' filter:
                - modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).
            For 'type' filter:
                - types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).

    Returns:
        List[MemoryUnit]: The list of memory units after filtering.

    Raises:
        MemoryCleanerError: If an unknown filter_type is provided or if required parameters are missing.
    """
    self.logger.debug(f"Filtering memory units using filter_type='{filter_type}' with kwargs={kwargs}")
    try:
        if filter_type == 'time':
            threshold_days = kwargs.get('threshold_days')
            if threshold_days is None:
                self.logger.error("Missing 'threshold_days' parameter for time-based filtering.")
                raise MemoryCleanerError("Missing 'threshold_days' parameter for time-based filtering.")
            return self.filter_by_time(memory_units, threshold_days)
        elif filter_type == 'modality':
            modalities = kwargs.get('modalities')
            if not modalities:
                self.logger.error("Missing 'modalities' parameter for modality-based filtering.")
                raise MemoryCleanerError("Missing 'modalities' parameter for modality-based filtering.")
            return self.filter_by_modality(memory_units, modalities)
        elif filter_type == 'type':
            types = kwargs.get('types')
            if not types:
                self.logger.error("Missing 'types' parameter for type-based filtering.")
                raise MemoryCleanerError("Missing 'types' parameter for type-based filtering.")
            return self.filter_by_type(memory_units, types)
        else:
            self.logger.error(f"Unknown filter_type: {filter_type}")
            raise MemoryCleanerError(f"Unknown filter_type: {filter_type}")
    except MemoryCleanerError:
        # Re-raise custom errors without modification
        raise
    except Exception as e:
        self.logger.error(f"Failed to filter memory units: {e}")
        raise MemoryCleanerError(f"Failed to filter memory units: {e}")
filter_by_modality(memory_units, modalities)

Keeps only memory units that match the specified modalities.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to be filtered. [required]
    modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']). [required]

Returns:

    List[MemoryUnit]: The list of memory units after modality-based filtering.

Raises:

    MemoryCleanerError: If filtering fails.

Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_modality(self, memory_units: List[MemoryUnit], modalities: List[str]) -> List[MemoryUnit]:
    """
    Keeps only memory units that match the specified modalities.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be filtered.
        modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).

    Returns:
        List[MemoryUnit]: The list of memory units after modality-based filtering.

    Raises:
        MemoryCleanerError: If filtering fails.
    """
    self.logger.debug(f"Applying modality-based filtering with modalities={modalities}")
    try:
        if not modalities:
            self.logger.warning("No modalities specified for modality-based filtering. Returning original memory units.")
            return memory_units

        filtered_memory = [
            mu for mu in memory_units
            if mu.modality in modalities
        ]
        removed_count = len(memory_units) - len(filtered_memory)
        self.logger.info(
            f"Modality-based filter: Removed {removed_count} memory units not in modalities {modalities}."
        )
        return filtered_memory
    except Exception as e:
        self.logger.error(f"Modality-based filtering failed: {e}")
        raise MemoryCleanerError(f"Modality-based filtering failed: {e}")
filter_by_time(memory_units, threshold_days)

Removes memory units older than the specified threshold_days.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to be filtered. [required]
    threshold_days (int): Number of days beyond which memory units are removed. [required]

Returns:

    List[MemoryUnit]: The list of memory units after time-based filtering.

Raises:

    MemoryCleanerError: If filtering fails.

Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_time(self, memory_units: List[MemoryUnit], threshold_days: int) -> List[MemoryUnit]:
    """
    Removes memory units older than the specified threshold_days.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be filtered.
        threshold_days (int): Number of days beyond which memory units are removed.

    Returns:
        List[MemoryUnit]: The list of memory units after time-based filtering.

    Raises:
        MemoryCleanerError: If filtering fails.
    """
    self.logger.debug(f"Applying time-based filtering with threshold_days={threshold_days}")
    try:
        current_time = datetime.now(UTC)
        threshold = timedelta(days=threshold_days)
        filtered_memory = [
            mu for mu in memory_units
            if (current_time - mu.timestamp) <= threshold
        ]
        removed_count = len(memory_units) - len(filtered_memory)
        self.logger.info(
            f"Time-based filter: Removed {removed_count} memory units older than {threshold_days} days."
        )
        return filtered_memory
    except Exception as e:
        self.logger.error(f"Time-based filtering failed: {e}")
        raise MemoryCleanerError(f"Time-based filtering failed: {e}")
filter_by_type(memory_units, types)

Keeps only memory units that match the specified types.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to be filtered. [required]
    types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']). [required]

Returns:

    List[MemoryUnit]: The list of memory units after type-based filtering.

Raises:

    MemoryCleanerError: If filtering fails.

Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_type(self, memory_units: List[MemoryUnit], types: List[str]) -> List[MemoryUnit]:
    """
    Keeps only memory units that match the specified types.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be filtered.
        types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).

    Returns:
        List[MemoryUnit]: The list of memory units after type-based filtering.

    Raises:
        MemoryCleanerError: If filtering fails.
    """
    self.logger.debug(f"Applying type-based filtering with types={types}")
    try:
        if not types:
            self.logger.warning("No types specified for type-based filtering. Returning original memory units.")
            return memory_units

        filtered_memory = [
            mu for mu in memory_units
            if mu.type in types
        ]
        removed_count = len(memory_units) - len(filtered_memory)
        self.logger.info(
            f"Type-based filter: Removed {removed_count} memory units not in types {types}."
        )
        return filtered_memory
    except Exception as e:
        self.logger.error(f"Type-based filtering failed: {e}")
        raise MemoryCleanerError(f"Type-based filtering failed: {e}")
MemoryCleanerError

Bases: Exception

Exception raised when an error occurs in the MemoryCleaner.

Source code in src/aeiva/cognition/memory/memory_cleaner.py
class MemoryCleanerError(Exception):
    """Exception raised when an error occurs in the MemoryCleaner."""
    pass
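
A minimal usage sketch of the cleaner, assuming MemoryUnit (import path inferred from the source paths above) accepts content, modality, and type keyword arguments and defaults its timestamp to a timezone-aware creation time:

import logging

from aeiva.cognition.memory.memory_cleaner import MemoryCleaner
from aeiva.cognition.memory.memory_unit import MemoryUnit  # path assumed

logging.basicConfig(level=logging.INFO)

units = [
    MemoryUnit(content="hello", modality="text", type="dialogue"),
    MemoryUnit(content="see attached chart", modality="image", type="observation"),
]

cleaner = MemoryCleaner()
recent = cleaner.filter(units, filter_type="time", threshold_days=30)  # drop stale units
text_only = cleaner.filter(recent, filter_type="modality", modalities=["text"])
dialogues = cleaner.filter(text_only, filter_type="type", types=["dialogue"])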

memory_config

MemoryConfig dataclass

Bases: BaseConfig

Configuration class for the Memory system.

Attributes:

    embedder_config (EmbedderConfig): Configuration for the embedding model.
    storage_config (StorageConfig): Configuration for the storage system.

Source code in src/aeiva/cognition/memory/memory_config.py
@dataclass
class MemoryConfig(BaseConfig):
    """
    Configuration class for the Memory system.

    Attributes:
        embedder_config (EmbedderConfig): Configuration for the embedding model.
        storage_config (StorageConfig): Configuration for the storage system.
    """

    embedder_config: EmbedderConfig = field(
        metadata={"help": "Configuration for the embedding model."}
    )
    storage_config: StorageConfig = field(
        metadata={"help": "Configuration for the storage system."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Perform any necessary validation
        if not self.embedder_config:
            raise ValueError("Embedder configuration must be provided.")
        if not self.storage_config:
            raise ValueError("Storage configuration must be provided.")
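
A hedged construction sketch (embedder_config and storage_config stand for pre-built EmbedderConfig and StorageConfig instances, whose fields are documented with their own classes):

from aeiva.cognition.memory.memory_config import MemoryConfig

# Passing a falsy value for either field raises ValueError in __post_init__.
memory_config = MemoryConfig(
    embedder_config=embedder_config,
    storage_config=storage_config,
)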

memory_link

MemoryLink

Bases: BaseModel

MemoryLink represents a relationship between two memory units, allowing complex structures to be built by linking individual memory units.

Attributes:

    id (str): Unique identifier for the edge, generated as a UUID string by default.
    source_id (str): Unique identifier of the source memory unit.
    target_id (str): Unique identifier of the target memory unit.
    relationship (str): Type of relationship between memory units, such as 'causal' or 'association'.
    metadata (Optional[Dict[str, Any]]): Additional metadata for the edge.

Source code in src/aeiva/cognition/memory/memory_link.py
class MemoryLink(BaseModel):
    """
    MemoryLink represents a relationship between two memory units, allowing
    complex structures to be built by linking individual memory units.

    Attributes:
        id (str): Unique identifier for the edge, generated as a UUID string by default.
        source_id (str): Unique identifier of the source memory unit.
        target_id (str): Unique identifier of the target memory unit.
        relationship (str): Type of relationship between memory units, such as 'causal' or 'association'.
        metadata (Optional[Dict[str, Any]]): Additional metadata for the edge.
    """
    id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the edge.")
    source_id: str = Field(..., description="Unique identifier of the source memory unit.")
    target_id: str = Field(..., description="Unique identifier of the target memory unit.")
    relationship: str = Field("", description="Type of relationship, e.g., 'causal', 'temporal'.")
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata for the edge.")

    def to_dict(self) -> dict:
        """Converts the MemoryLink instance to a dictionary format for serialization."""
        return self.dict()

    @classmethod
    def from_dict(cls, data: dict) -> "MemoryLink":
        """Creates a MemoryLink instance from a dictionary."""
        return cls(**data)
from_dict(data) classmethod

Creates a MemoryLink instance from a dictionary.

Source code in src/aeiva/cognition/memory/memory_link.py
@classmethod
def from_dict(cls, data: dict) -> "MemoryLink":
    """Creates a MemoryLink instance from a dictionary."""
    return cls(**data)
to_dict()

Converts the MemoryLink instance to a dictionary format for serialization.

Source code in src/aeiva/cognition/memory/memory_link.py
def to_dict(self) -> dict:
    """Converts the MemoryLink instance to a dictionary format for serialization."""
    return self.dict()
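
A short round trip through the serialization helpers:

from aeiva.cognition.memory.memory_link import MemoryLink

link = MemoryLink(source_id="unit-a", target_id="unit-b", relationship="causal")
payload = link.to_dict()                  # plain dict, e.g. for JSON storage
restored = MemoryLink.from_dict(payload)  # reconstructs an equivalent link
assert restored.source_id == link.source_id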

memory_organizer

MemoryOrganizer

A class to organize memory units based on various organizing algorithms.

Supported organize types:
  • 'dialogue': Groups memory units by 'dialogue_session_id'.
Future organize types can be added here.
Source code in src/aeiva/cognition/memory/memory_organizer.py
class MemoryOrganizer:
    """
    A class to organize memory units based on various organizing algorithms.

    Supported organize types:
        - 'dialogue': Groups memory units by 'dialogue_session_id'.
        # Future organize types can be added here.
    """

    def __init__(self):
        """
        Initializes the MemoryOrganizer.

        Currently, no initialization parameters are required.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug("Initialized MemoryOrganizer without default parameters.")

    def organize(
        self,
        memory_units: List[MemoryUnit],
        organize_type: str,
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Organizes the provided memory units based on the specified organize type.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be organized.
            organize_type (str): The type of organizing algorithm to use ('dialogue').
            **kwargs: Additional parameters required for specific organizers.
                For 'dialogue' organize:
                    - group_field (str): The metadata field to group by (default: 'dialogue_session_id').
                    - derive_content (bool): Whether to derive content for the group (default: True).
                    - derivation_type (str): The type of derivation to perform ('summary', etc.).

        Returns:
            List[MemoryUnit]: The list of memory units after organizing.

        Raises:
            MemoryOrganizerError: If an unknown organize_type is provided or if required parameters are missing.
        """
        self.logger.debug(f"Organizing memory units using organize_type='{organize_type}' with kwargs={kwargs}")
        try:
            if organize_type == 'dialogue':
                group_field = kwargs.get('group_field', 'dialogue_session_id')
                derive_content = kwargs.get('derive_content', True)
                derivation_type = kwargs.get('derivation_type', 'summary')
                return self.organize_by_dialogue(memory_units, group_field, derive_content, derivation_type)
            else:
                self.logger.error(f"Unknown organize_type: {organize_type}")
                raise MemoryOrganizerError(f"Unknown organize_type: {organize_type}")
        except MemoryOrganizerError:
            # Re-raise custom errors without modification
            raise
        except Exception as e:
            self.logger.error(f"Failed to organize memory units: {e}")
            raise MemoryOrganizerError(f"Failed to organize memory units: {e}")

    def organize_by_dialogue(
        self,
        memory_units: List[MemoryUnit],
        group_field: str = 'dialogue_session_id',  # NOTE: here we assume the metadata field of dialogue memory units has a dialogue_session_id
        derive_content: bool = False,
        derivation_type: str = 'summary'
    ) -> List[MemoryUnit]:
        """
        Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be organized.
            group_field (str): The metadata field to group by (default: 'dialogue_session_id').
            derive_content (bool): Whether to derive content for the group (default: False).
            derivation_type (str): The type of derivation to perform ('summary', etc.).

        Returns:
            List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.

        Raises:
            MemoryOrganizerError: If organizing fails.
        """
        self.logger.debug(f"Organizing by dialogue with group_field='{group_field}', derive_content={derive_content}, derivation_type='{derivation_type}'")
        try:
            # Group memory units by the specified group_field
            groups = defaultdict(list)
            for mu in memory_units:
                group_id = mu.metadata.get(group_field)
                if group_id:
                    groups[group_id].append(mu)
                else:
                    self.logger.debug(f"MemoryUnit '{mu.id}' does not have '{group_field}'. Skipping grouping.")

            self.logger.info(f"Found {len(groups)} dialogue groups based on '{group_field}'.")

            # Create new MemoryUnit for each group
            new_memory_units = []
            for group_id, group_mus in groups.items():
                self.logger.debug(f"Creating DialogueGroup for group_id='{group_id}' with {len(group_mus)} memory units.")

                # Create a new MemoryUnit to represent the DialogueGroup
                dialogue_group = MemoryUnit(
                    content="",  # Content to be derived
                    type="dialogue_session",
                    metadata={
                        "organized_at": datetime.now(timezone.utc).isoformat(),
                        "member_ids": [mu.id for mu in group_mus],
                        "derivation_type": derivation_type
                    }
                )

                # Link each memory unit to the DialogueGroup
                for mu in group_mus:
                    link = MemoryLink(
                        source_id=mu.id,
                        target_id=dialogue_group.id,
                        relationship='part_of'
                    )
                    mu.edges.append(link)
                    self.logger.debug(f"Linked MemoryUnit '{mu.id}' to DialogueGroup '{dialogue_group.id}'.")

                # Optionally, derive content for the group
                if derive_content:
                    if derivation_type == 'summary':
                        derived_content = self.derive_summary(group_mus)
                    elif derivation_type == 'reflection':
                        derived_content = self.derive_reflection(group_mus)
                    else:
                        self.logger.warning(f"Unknown derivation_type '{derivation_type}'. Skipping content derivation.")
                        derived_content = ""
                    dialogue_group.content = derived_content
                    dialogue_group.status = 'derived'
                    self.logger.debug(f"Derived content for DialogueGroup '{dialogue_group.id}': {derived_content}")

                new_memory_units.append(dialogue_group)
                self.logger.info(f"DialogueGroup '{dialogue_group.id}' created for group_id='{group_id}'.")

            # Return the original memory units plus the new dialogue groups
            organized_memory = memory_units + new_memory_units
            self.logger.debug(f"Organizing by dialogue completed. Total memory units after organizing: {len(organized_memory)}")
            return organized_memory

        except Exception as e:
            self.logger.error(f"Error organizing by dialogue: {e}")
            raise MemoryOrganizerError(f"Error organizing by dialogue: {e}")

    def derive_summary(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation
        """
        Derives a summary from the given memory units.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to summarize.

        Returns:
            str: A summary string.
        """
        self.logger.debug(f"Deriving summary from {len(memory_units)} memory units.")
        try:
            summary = "Summary of dialogue session:\n"
            for mu in memory_units:
                summary += f"- {mu.content}\n"
            derived_summary = summary.strip()
            self.logger.debug(f"Derived summary: {derived_summary}")
            return derived_summary
        except Exception as e:
            self.logger.error(f"Failed to derive summary: {e}")
            raise MemoryOrganizerError(f"Failed to derive summary: {e}")

    def derive_reflection(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation
        """
        Derives a reflection from the given memory units.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to reflect upon.

        Returns:
            str: A reflection string.
        """
        self.logger.debug(f"Deriving reflection from {len(memory_units)} memory units.")
        try:
            reflection = "Reflection on dialogue session:\n"
            for mu in memory_units:
                reflection += f"- {mu.content}\n"
            derived_reflection = reflection.strip()
            self.logger.debug(f"Derived reflection: {derived_reflection}")
            return derived_reflection
        except Exception as e:
            self.logger.error(f"Failed to derive reflection: {e}")
            raise MemoryOrganizerError(f"Failed to derive reflection: {e}")
__init__()

Initializes the MemoryOrganizer.

Currently, no initialization parameters are required.

Source code in src/aeiva/cognition/memory/memory_organizer.py
def __init__(self):
    """
    Initializes the MemoryOrganizer.

    Currently, no initialization parameters are required.
    """
    self.logger = logging.getLogger(self.__class__.__name__)
    self.logger.debug("Initialized MemoryOrganizer without default parameters.")
derive_reflection(memory_units)

Derives a reflection from the given memory units.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to reflect upon. [required]

Returns:

    str: A reflection string.

Source code in src/aeiva/cognition/memory/memory_organizer.py
def derive_reflection(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation
    """
    Derives a reflection from the given memory units.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to reflect upon.

    Returns:
        str: A reflection string.
    """
    self.logger.debug(f"Deriving reflection from {len(memory_units)} memory units.")
    try:
        reflection = "Reflection on dialogue session:\n"
        for mu in memory_units:
            reflection += f"- {mu.content}\n"
        derived_reflection = reflection.strip()
        self.logger.debug(f"Derived reflection: {derived_reflection}")
        return derived_reflection
    except Exception as e:
        self.logger.error(f"Failed to derive reflection: {e}")
        raise MemoryOrganizerError(f"Failed to derive reflection: {e}")
derive_summary(memory_units)

Derives a summary from the given memory units.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to summarize. [required]

Returns:

    str: A summary string.

Source code in src/aeiva/cognition/memory/memory_organizer.py
def derive_summary(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation
    """
    Derives a summary from the given memory units.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to summarize.

    Returns:
        str: A summary string.
    """
    self.logger.debug(f"Deriving summary from {len(memory_units)} memory units.")
    try:
        summary = "Summary of dialogue session:\n"
        for mu in memory_units:
            summary += f"- {mu.content}\n"
        derived_summary = summary.strip()
        self.logger.debug(f"Derived summary: {derived_summary}")
        return derived_summary
    except Exception as e:
        self.logger.error(f"Failed to derive summary: {e}")
        raise MemoryOrganizerError(f"Failed to derive summary: {e}")
organize(memory_units, organize_type, **kwargs)

Organizes the provided memory units based on the specified organize type.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to be organized. [required]
    organize_type (str): The type of organizing algorithm to use ('dialogue'). [required]
    **kwargs: Additional parameters required for specific organizers.
        For 'dialogue': group_field (str) - the metadata field to group by (default: 'dialogue_session_id');
        derive_content (bool) - whether to derive content for the group (default: True);
        derivation_type (str) - the type of derivation to perform ('summary', etc.).

Returns:

    List[MemoryUnit]: The list of memory units after organizing.

Raises:

    MemoryOrganizerError: If an unknown organize_type is provided or if required parameters are missing.

Source code in src/aeiva/cognition/memory/memory_organizer.py
def organize(
    self,
    memory_units: List[MemoryUnit],
    organize_type: str,
    **kwargs
) -> List[MemoryUnit]:
    """
    Organizes the provided memory units based on the specified organize type.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be organized.
        organize_type (str): The type of organizing algorithm to use ('dialogue').
        **kwargs: Additional parameters required for specific organizers.
            For 'dialogue' organize:
                - group_field (str): The metadata field to group by (default: 'dialogue_session_id').
                - derive_content (bool): Whether to derive content for the group (default: True).
                - derivation_type (str): The type of derivation to perform ('summary', etc.).

    Returns:
        List[MemoryUnit]: The list of memory units after organizing.

    Raises:
        MemoryOrganizerError: If an unknown organize_type is provided or if required parameters are missing.
    """
    self.logger.debug(f"Organizing memory units using organize_type='{organize_type}' with kwargs={kwargs}")
    try:
        if organize_type == 'dialogue':
            group_field = kwargs.get('group_field', 'dialogue_session_id')
            derive_content = kwargs.get('derive_content', True)
            derivation_type = kwargs.get('derivation_type', 'summary')
            return self.organize_by_dialogue(memory_units, group_field, derive_content, derivation_type)
        else:
            self.logger.error(f"Unknown organize_type: {organize_type}")
            raise MemoryOrganizerError(f"Unknown organize_type: {organize_type}")
    except MemoryOrganizerError:
        # Re-raise custom errors without modification
        raise
    except Exception as e:
        self.logger.error(f"Failed to organize memory units: {e}")
        raise MemoryOrganizerError(f"Failed to organize memory units: {e}")
organize_by_dialogue(memory_units, group_field='dialogue_session_id', derive_content=False, derivation_type='summary')

Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to be organized. [required]
    group_field (str): The metadata field to group by. [default: 'dialogue_session_id']
    derive_content (bool): Whether to derive content for the group. [default: False]
    derivation_type (str): The type of derivation to perform ('summary', etc.). [default: 'summary']

Returns:

    List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.

Raises:

    MemoryOrganizerError: If organizing fails.

Source code in src/aeiva/cognition/memory/memory_organizer.py
def organize_by_dialogue(
    self,
    memory_units: List[MemoryUnit],
    group_field: str = 'dialogue_session_id',  # NOTE: here we assume the metadata field of dialogue memory units has a dialogue_session_id
    derive_content: bool = False,
    derivation_type: str = 'summary'
) -> List[MemoryUnit]:
    """
    Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be organized.
        group_field (str): The metadata field to group by (default: 'dialogue_session_id').
        derive_content (bool): Whether to derive content for the group (default: False).
        derivation_type (str): The type of derivation to perform ('summary', etc.).

    Returns:
        List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.

    Raises:
        MemoryOrganizerError: If organizing fails.
    """
    self.logger.debug(f"Organizing by dialogue with group_field='{group_field}', derive_content={derive_content}, derivation_type='{derivation_type}'")
    try:
        # Group memory units by the specified group_field
        groups = defaultdict(list)
        for mu in memory_units:
            group_id = mu.metadata.get(group_field)
            if group_id:
                groups[group_id].append(mu)
            else:
                self.logger.debug(f"MemoryUnit '{mu.id}' does not have '{group_field}'. Skipping grouping.")

        self.logger.info(f"Found {len(groups)} dialogue groups based on '{group_field}'.")

        # Create new MemoryUnit for each group
        new_memory_units = []
        for group_id, group_mus in groups.items():
            self.logger.debug(f"Creating DialogueGroup for group_id='{group_id}' with {len(group_mus)} memory units.")

            # Create a new MemoryUnit to represent the DialogueGroup
            dialogue_group = MemoryUnit(
                content="",  # Content to be derived
                type="dialogue_session",
                metadata={
                    "organized_at": datetime.now(timezone.utc).isoformat(),
                    "member_ids": [mu.id for mu in group_mus],
                    "derivation_type": derivation_type
                }
            )

            # Link each memory unit to the DialogueGroup
            for mu in group_mus:
                link = MemoryLink(
                    source_id=mu.id,
                    target_id=dialogue_group.id,
                    relationship='part_of'
                )
                mu.edges.append(link)
                self.logger.debug(f"Linked MemoryUnit '{mu.id}' to DialogueGroup '{dialogue_group.id}'.")

            # Optionally, derive content for the group
            if derive_content:
                if derivation_type == 'summary':
                    derived_content = self.derive_summary(group_mus)
                elif derivation_type == 'reflection':
                    derived_content = self.derive_reflection(group_mus)
                else:
                    self.logger.warning(f"Unknown derivation_type '{derivation_type}'. Skipping content derivation.")
                    derived_content = ""
                dialogue_group.content = derived_content
                dialogue_group.status = 'derived'
                self.logger.debug(f"Derived content for DialogueGroup '{dialogue_group.id}': {derived_content}")

            new_memory_units.append(dialogue_group)
            self.logger.info(f"DialogueGroup '{dialogue_group.id}' created for group_id='{group_id}'.")

        # Return the original memory units plus the new dialogue groups
        organized_memory = memory_units + new_memory_units
        self.logger.debug(f"Organizing by dialogue completed. Total memory units after organizing: {len(organized_memory)}")
        return organized_memory

    except Exception as e:
        self.logger.error(f"Error organizing by dialogue: {e}")
        raise MemoryOrganizerError(f"Error organizing by dialogue: {e}")
MemoryOrganizerError

Bases: Exception

Exception raised when an error occurs in the MemoryOrganizer.

Source code in src/aeiva/cognition/memory/memory_organizer.py
class MemoryOrganizerError(Exception):
    """Exception raised when an error occurs in the MemoryOrganizer."""
    pass
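
A minimal sketch of dialogue grouping, assuming each unit's metadata carries a dialogue_session_id (as the NOTE in organize_by_dialogue expects) and that MemoryUnit (path assumed) accepts content and metadata keyword arguments:

from aeiva.cognition.memory.memory_organizer import MemoryOrganizer
from aeiva.cognition.memory.memory_unit import MemoryUnit  # path assumed

units = [
    MemoryUnit(content="Hi!", metadata={"dialogue_session_id": "s1"}),
    MemoryUnit(content="Hello, how can I help?", metadata={"dialogue_session_id": "s1"}),
]

organizer = MemoryOrganizer()
# Returns the original units plus one new 'dialogue_session' unit per session;
# derive_content=True fills the group unit's content with a naive summary.
organized = organizer.organize(units, organize_type="dialogue", derive_content=True)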

memory_palace

MemoryPalace

Bases: Memory

Concrete implementation of the Memory abstract base class.

This class provides methods to manage memory units, including creation, retrieval, updating, deletion, filtering, grouping, structurizing, skillizing, parameterizing, and more. It delegates specific operations to specialized components like MemoryCleaner, MemoryOrganizer, MemoryRetriever, MemoryStructurer, MemorySkillizer, and MemoryParameterizer.
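
A minimal construction sketch (the nested values are placeholders; the top-level keys 'embedder_config' and 'storage_config' are the ones setup() reads):

from aeiva.cognition.memory.memory_palace import MemoryPalace

config = {
    "embedder_config": {...},  # provider/model settings consumed by Embedder
    "storage_config": {...},   # backend settings consumed by MemoryStorage
}

palace = MemoryPalace(config)  # __init__ runs setup() immediately
unit = palace.create("The user prefers short answers.", type="dialogue")
same = palace.get(unit.id)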

Source code in src/aeiva/cognition/memory/memory_palace.py
class MemoryPalace(Memory):
    """
    Concrete implementation of the Memory abstract base class.

    This class provides methods to manage memory units, including creation, retrieval,
    updating, deletion, filtering, grouping, structurizing, skillizing, parameterizing,
    and more. It delegates specific operations to specialized components like
    MemoryCleaner, MemoryOrganizer, MemoryRetriever, MemoryStructurer, MemorySkillizer,
    and MemoryParameterizer.
    """

    def __init__(self, config: Dict):
        """
        Initialize the MemoryPalace with the provided configuration.

        Args:
            config (Dict): Configuration settings for the MemoryPalace.
        """
        self.config_dict = config
        self.config = None
        self.storage = None
        self.embedder = None
        self.cleaner = None
        self.organizer = None
        self.retriever = None
        self.structurer = None
        self.skillizer = None
        self.parameterizer = None
        self.setup()

    def setup(self):
        """
        Setup the MemoryPalace by initializing all components.
        """
        try:
            # Initialize EmbedderConfig
            embedder_config_dict = self.config_dict.get('embedder_config', {})
            self.embedder = Embedder(embedder_config_dict)

            storage_config_dict = self.config_dict.get('storage_config', {})
            self.storage = MemoryStorage(storage_config_dict) 

            # Initialize Memory Configuration
            self.config = MemoryConfig(
                embedder_config=self.embedder.config,
                storage_config=self.storage.config
            )

            logger.info("MemoryPalace: MemoryStorage and Embedder initialized successfully.")

            # Initialize specialized components
            self.cleaner = MemoryCleaner()
            self.organizer = MemoryOrganizer()
            self.retriever = MemoryRetriever(embedder=self.embedder, storage=self.storage)
            self.structurer = MemoryStructurer()
            self.skillizer = MemorySkillizer()
            self.parameterizer = MemoryParameterizer()
            logger.info("MemoryPalace: Specialized components initialized successfully.")

        except Exception as e:
            logger.error(f"MemoryPalace setup failed: {e}")
            self.handle_error(e)
            raise

    # CRUD Operations

    def create(self, content: Any, **kwargs) -> MemoryUnit:
        """
        Creates a new memory unit with the given content and metadata.

        Args:
            content (Any): The core content of the memory unit.
            **kwargs: Additional metadata for the memory unit.

        Returns:
            MemoryUnit: The created memory unit.
        """
        try:
            # Instantiate MemoryUnit
            memory_unit = MemoryUnit(content=content, **kwargs)

            # Generate embedding
            embedding_response = self.embedder.embed(content)
            if embedding_response.get("data"):
                memory_unit.embedding = embedding_response["data"][0].get("embedding")
            else:
                raise ValueError("Failed to generate embedding for the content.")

            # Delegate storage operations to MemoryStorage
            self.storage.add_memory_unit(memory_unit)

            logger.info(f"Created new MemoryUnit with ID: {memory_unit.id}")
            return memory_unit
        except Exception as e:
            logger.error(f"Error creating MemoryUnit: {e}")
            self.handle_error(e)
            raise

    def get(self, unit_id: str) -> MemoryUnit:
        """
        Retrieves a memory unit by its unique identifier.

        Args:
            unit_id (str): The unique identifier of the memory unit.

        Returns:
            MemoryUnit: The retrieved memory unit.
        """
        try:
            memory_unit = self.storage.get_memory_unit(unit_id)
            logger.info(f"Retrieved MemoryUnit with ID: {unit_id}")
            return memory_unit
        except Exception as e:
            logger.error(f"Error retrieving MemoryUnit with ID {unit_id}: {e}")
            self.handle_error(e)
            raise

    def update(self, unit_id: str, updates: Dict[str, Any]) -> None:
        """
        Updates a memory unit with the given updates.

        Args:
            unit_id (str): The unique identifier of the memory unit.
            updates (Dict[str, Any]): A dictionary of fields to update.
        """
        try:
            # Delegate update operations to MemoryStorage
            self.storage.update_memory_unit(unit_id, updates)
            logger.info(f"Updated MemoryUnit with ID: {unit_id}")
        except Exception as e:
            logger.error(f"Error updating MemoryUnit with ID {unit_id}: {e}")
            self.handle_error(e)
            raise

    def delete(self, unit_id: str) -> None:
        """
        Deletes a memory unit by its unique identifier.

        Args:
            unit_id (str): The unique identifier of the memory unit.
        """
        try:
            # Delegate deletion to MemoryStorage
            self.storage.delete_memory_unit(unit_id)
            logger.info(f"Deleted MemoryUnit with ID: {unit_id}")
        except Exception as e:
            logger.error(f"Error deleting MemoryUnit with ID {unit_id}: {e}")
            self.handle_error(e)
            raise

    def get_all(self) -> List[MemoryUnit]:
        """
        Retrieves all memory units.

        Returns:
            List[MemoryUnit]: A list of all memory units.
        """
        try:
            memory_units = self.storage.get_all_memory_units()
            logger.info(f"Retrieved all MemoryUnits. Total count: {len(memory_units)}")
            return memory_units
        except Exception as e:
            logger.error(f"Error retrieving all MemoryUnits: {e}")
            self.handle_error(e)
            raise

    def delete_all(self) -> None:
        """
        Deletes all memory units.
        """
        try:
            self.storage.delete_all_memory_units()  # TODO: does not seem to work correctly; needs investigation
            logger.info("Deleted all MemoryUnits.")
        except Exception as e:
            logger.error(f"Error deleting all MemoryUnits: {e}")
            self.handle_error(e)
            raise

    def load(self) -> List[MemoryUnit]:
        """
        Loads all memory units from the storage.

        Returns:
            List[MemoryUnit]: A list of all loaded memory units.
        """
        try:
            # Retrieve all memory units from storage
            memory_units = self.get_all()
            logger.info(f"Loaded {len(memory_units)} MemoryUnits from storage.")
            return memory_units
        except Exception as e:
            logger.error(f"Error loading MemoryUnits: {e}")
            self.handle_error(e)
            raise

    def save(self, export_path: Optional[str] = None) -> None:
        """
        Saves all memory units to the storage or exports them to a specified path.

        Args:
            export_path (Optional[str]): The file path to export memory units as JSON.
                                        If None, saves are handled by MemoryStorage.
        """
        try:
            if export_path:
                # Export memory units to a JSON file
                memory_units = self.get_all()
                export_data = [mu.to_dict() for mu in memory_units]
                with open(export_path, 'w', encoding='utf-8') as f:
                    json.dump(export_data, f, ensure_ascii=False, indent=4)
                logger.info(f"Exported {len(memory_units)} MemoryUnits to {export_path}.")
            else:
                # If no export path is provided, assume that MemoryStorage handles persistence
                logger.info("Save operation delegated to MemoryStorage.")
                # Example: self.storage.persist_changes()
        except Exception as e:
            logger.error(f"Error saving MemoryUnits: {e}")
            self.handle_error(e)
            raise

    # Delegated Operations

    def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:
        """
        Filters memory units based on the given criteria.

        Args:
            criteria (Dict[str, Any]): A dictionary of filter conditions.

        Returns:
            List[MemoryUnit]: A list of memory units matching the criteria.
        """
        try:
            memory_units = self.get_all()
            filter_type = criteria.get('filter_type')
            if not filter_type:
                raise ValueError("Missing 'filter_type' in criteria.")

            # Delegate filtering to MemoryCleaner
            filtered_memories = self.cleaner.filter(memory_units, filter_type, **criteria)
            logger.info(f"Filtered memories based on criteria: {criteria}")
            return filtered_memories
        except Exception as e:
            logger.error(f"Error filtering memories: {e}")
            self.handle_error(e)
            raise

    def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Groups memory units into a meaningful group.

        Args:
            unit_ids (List[str]): A list of memory unit IDs to group.
            organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').
            metadata (Optional[Dict[str, Any]]): Additional metadata for the group.

        Returns:
            str: A unique identifier for the created group.
        """
        try:
            # Retrieve the memory units to group
            memory_units = [self.get(unit_id) for unit_id in unit_ids]
            logger.debug(f"Grouping {len(memory_units)} MemoryUnits into group_type='{organize_type}'.")

            # Delegate grouping to MemoryOrganizer
            organized_memories = self.organizer.organize(memory_units, organize_type, metadata=metadata)
            logger.info(f"Grouped memories into '{organize_type}'. Total memory units after grouping: {len(organized_memories)}")
            return "group_id_placeholder"  # Replace with actual group ID if applicable
        except Exception as e:
            logger.error(f"Error grouping memories: {e}")
            self.handle_error(e)
            raise

    def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:
        """
        Structures memory units into a knowledge graph or other structures.

        Args:
            unit_ids (List[str]): A list of memory unit IDs to structurize.
            structure_type (str): The type of structure (e.g., 'knowledge_graph').
            **kwargs: Additional parameters for the structuring process.
        """
        try:
            # Retrieve the memory units to structurize
            memory_units = [self.get(uid) for uid in unit_ids]
            logger.debug(f"Structurizing {len(memory_units)} MemoryUnits with structure_type='{structure_type}'.")

            # Delegate structuring to MemoryStructurer
            self.structurer.structure(memory_units, structure_type, **kwargs)
            logger.info(f"Structurized memories with structure_type='{structure_type}'.")
        except Exception as e:
            logger.error(f"Error structurizing memories: {e}")
            self.handle_error(e)
            raise

    def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:
        """
        Converts memory units into a reusable skill.

        Args:
            unit_ids (List[str]): A list of memory unit IDs to skillize.
            skill_name (str): The name of the skill to create.
            **kwargs: Additional parameters for skill creation.

        Returns:
            str: The unique identifier of the created skill.
        """
        try:
            # Retrieve the memory units to skillize
            memory_units = [self.get(uid) for uid in unit_ids]
            logger.debug(f"Skillizing {len(memory_units)} MemoryUnits into skill_name='{skill_name}'.")

            # Delegate skillizing to MemorySkillizer
            skill_id = self.skillizer.skillize(memory_units, skill_name, **kwargs)
            logger.info(f"Skillized memories into skill with ID: {skill_id}")
            return skill_id
        except Exception as e:
            logger.error(f"Error skillizing memories: {e}")
            self.handle_error(e)
            raise

    def parameterize(self, **kwargs) -> None:
        """
        Trains a parametric model using the memory data.

        Args:
            **kwargs: Additional parameters for the training process.
        """
        try:
            # Retrieve all memory units
            memory_units = self.get_all()
            logger.debug(f"Parameterizing {len(memory_units)} MemoryUnits.")

            # Delegate parameterizing to MemoryParameterizer
            self.parameterizer.parameterize(memory_units, **kwargs)
            logger.info("Parameterized memories successfully.")
        except Exception as e:
            logger.error(f"Error parameterizing memories: {e}")
            self.handle_error(e)
            raise

    def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:
        """
        Retrieve data from memory based on a query.

        Args:
            query (Any): The query or criteria to retrieve specific memory data.
            retrieve_type (str): The type of retrieval (e.g., 'similar', 'related').
            **kwargs: Additional parameters for the retrieval process.

        Returns:
            List[MemoryUnit]: The retrieved memory data.
        """
        try:
            # Delegate retrieval to MemoryRetriever
            memories = self.retriever.retrieve(query=query, retrieve_type=retrieve_type, **kwargs)
            logger.info(f"Retrieved {len(memories)} memories using retrieve_type='{retrieve_type}'.")
            return memories
        except Exception as e:
            logger.error(f"Error retrieving MemoryUnits: {e}")
            self.handle_error(e)
            raise

    def embed(self, unit_id: str) -> None:
        """
        Generates an embedding for a memory unit.

        Args:
            unit_id (str): The unique identifier of the memory unit.
        """
        try:
            # Look up the target memory unit via MemoryRetriever (similarity search over the unit ID)
            memory_units = self.retriever.retrieve(query=unit_id, retrieve_type='similar', top_k=1)
            if not memory_units:
                raise ValueError(f"No MemoryUnit found with ID {unit_id} to embed.")

            memory_unit = memory_units[0]

            # Generate embedding using the embedder
            embedding_response = self.embedder.embed(memory_unit.content)
            if embedding_response.get("data") and len(embedding_response["data"]) > 0:
                memory_unit.embedding = embedding_response["data"][0].get("embedding")
            else:
                raise ValueError("Failed to generate embedding for the content.")

            # Update the memory unit with the new embedding
            self.update(unit_id, {'embedding': memory_unit.embedding})

            logger.info(f"Generated embedding for MemoryUnit ID: {unit_id}")
        except Exception as e:
            logger.error(f"Error generating embedding for MemoryUnit ID {unit_id}: {e}")
            self.handle_error(e)
            raise

    # Error Handling

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during memory operations.

        Args:
            error (Exception): The exception that was raised.
        """
        logger.error(f"MemoryPalace encountered an error: {error}")
        # Additional error handling can be implemented here

    @staticmethod
    def get_api_key(config_section: Dict[str, Any], key_field: str, env_var_field: str) -> Optional[str]:
        """
        Retrieve an API key from the configuration section.

        Args:
            config_section (Dict[str, Any]): The configuration section (e.g., embedder_config).
            key_field (str): The key in the config_section that may contain the API key directly.
            env_var_field (str): The key in the config_section that specifies the environment variable name.

        Returns:
            Optional[str]: The API key if found, else None.

        Raises:
            EnvironmentError: If the environment variable is specified but not set.
        """
        # Check if API key is provided directly
        api_key = config_section.get(key_field)
        if api_key:
            logger.info(f"Using provided API key for '{key_field}'.")
            return api_key

        # Else, check if an environment variable is specified
        env_var = config_section.get(env_var_field)
        if env_var:
            api_key = os.getenv(env_var)
            if api_key:
                logger.info(f"Retrieved API key for '{key_field}' from environment variable '{env_var}'.")
                return api_key
            else:
                logger.error(f"Environment variable '{env_var}' for '{key_field}' is not set.")
                raise EnvironmentError(f"Environment variable '{env_var}' for '{key_field}' is not set.")

        logger.warning(f"No API key provided for '{key_field}'.")
        return None
__init__(config)

Initialize the MemoryPalace with the provided configuration.

Parameters:

Name Type Description Default
config Dict

Configuration settings for the MemoryPalace.

required
Source code in src/aeiva/cognition/memory/memory_palace.py
def __init__(self, config: Dict):
    """
    Initialize the MemoryPalace with the provided configuration.

    Args:
        config (Dict): Configuration settings for the MemoryPalace.
    """
    self.config_dict = config
    self.config = None
    self.storage = None
    self.embedder = None
    self.cleaner = None
    self.organizer = None
    self.retriever = None
    self.structurer = None
    self.skillizer = None
    self.parameterizer = None
    self.setup()
create(content, **kwargs)

Creates a new memory unit with the given content and metadata.

Parameters:

Name Type Description Default
content Any

The core content of the memory unit.

required
**kwargs

Additional metadata for the memory unit.

{}

Returns:

Name Type Description
MemoryUnit MemoryUnit

The created memory unit.

Source code in src/aeiva/cognition/memory/memory_palace.py
def create(self, content: Any, **kwargs) -> MemoryUnit:
    """
    Creates a new memory unit with the given content and metadata.

    Args:
        content (Any): The core content of the memory unit.
        **kwargs: Additional metadata for the memory unit.

    Returns:
        MemoryUnit: The created memory unit.
    """
    try:
        # Instantiate MemoryUnit
        memory_unit = MemoryUnit(content=content, **kwargs)

        # Generate embedding
        embedding_response = self.embedder.embed(content)
        if embedding_response.get("data"):
            memory_unit.embedding = embedding_response["data"][0].get("embedding")
        else:
            raise ValueError("Failed to generate embedding for the content.")

        # Delegate storage operations to MemoryStorage
        self.storage.add_memory_unit(memory_unit)

        logger.info(f"Created new MemoryUnit with ID: {memory_unit.id}")
        return memory_unit
    except Exception as e:
        logger.error(f"Error creating MemoryUnit: {e}")
        self.handle_error(e)
        raise
delete(unit_id)

Deletes a memory unit by its unique identifier.

Parameters:

Name Type Description Default
unit_id str

The unique identifier of the memory unit.

required
Source code in src/aeiva/cognition/memory/memory_palace.py
def delete(self, unit_id: str) -> None:
    """
    Deletes a memory unit by its unique identifier.

    Args:
        unit_id (str): The unique identifier of the memory unit.
    """
    try:
        # Delegate deletion to MemoryStorage
        self.storage.delete_memory_unit(unit_id)
        logger.info(f"Deleted MemoryUnit with ID: {unit_id}")
    except Exception as e:
        logger.error(f"Error deleting MemoryUnit with ID {unit_id}: {e}")
        self.handle_error(e)
        raise
delete_all()

Deletes all memory units.

Source code in src/aeiva/cognition/memory/memory_palace.py
def delete_all(self) -> None:
    """
    Deletes all memory units.
    """
    try:
        self.storage.delete_all_memory_units()  # TODO: does not seem to work correctly; needs investigation
        logger.info("Deleted all MemoryUnits.")
    except Exception as e:
        logger.error(f"Error deleting all MemoryUnits: {e}")
        self.handle_error(e)
        raise
embed(unit_id)

Generates an embedding for a memory unit.

Parameters:

Name Type Description Default
unit_id str

The unique identifier of the memory unit.

required
Source code in src/aeiva/cognition/memory/memory_palace.py
def embed(self, unit_id: str) -> None:
    """
    Generates an embedding for a memory unit.

    Args:
        unit_id (str): The unique identifier of the memory unit.
    """
    try:
        # Look up the target memory unit via MemoryRetriever (similarity search over the unit ID)
        memory_units = self.retriever.retrieve(query=unit_id, retrieve_type='similar', top_k=1)
        if not memory_units:
            raise ValueError(f"No MemoryUnit found with ID {unit_id} to embed.")

        memory_unit = memory_units[0]

        # Generate embedding using the embedder
        embedding_response = self.embedder.embed(memory_unit.content)
        if embedding_response.get("data") and len(embedding_response["data"]) > 0:
            memory_unit.embedding = embedding_response["data"][0].get("embedding")
        else:
            raise ValueError("Failed to generate embedding for the content.")

        # Update the memory unit with the new embedding
        self.update(unit_id, {'embedding': memory_unit.embedding})

        logger.info(f"Generated embedding for MemoryUnit ID: {unit_id}")
    except Exception as e:
        logger.error(f"Error generating embedding for MemoryUnit ID {unit_id}: {e}")
        self.handle_error(e)
        raise
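
Refreshing a unit's embedding is then a single call (sketch; note the lookup above goes through a similarity search on the unit ID rather than a direct fetch):

palace.embed(unit.id)   # regenerates the embedding and persists it via update()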
filter(criteria)

Filters memory units based on the given criteria.

Parameters:

Name Type Description Default
criteria Dict[str, Any]

A dictionary of filter conditions.

required

Returns:

Type Description
List[MemoryUnit]

List[MemoryUnit]: A list of memory units matching the criteria.

Source code in src/aeiva/cognition/memory/memory_palace.py
def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:
    """
    Filters memory units based on the given criteria.

    Args:
        criteria (Dict[str, Any]): A dictionary of filter conditions.

    Returns:
        List[MemoryUnit]: A list of memory units matching the criteria.
    """
    try:
        memory_units = self.get_all()
        filter_type = criteria.get('filter_type')
        if not filter_type:
            raise ValueError("Missing 'filter_type' in criteria.")

        # Delegate filtering to MemoryCleaner
        filtered_memories = self.cleaner.filter(memory_units, filter_type, **criteria)
        logger.info(f"Filtered memories based on criteria: {criteria}")
        return filtered_memories
    except Exception as e:
        logger.error(f"Error filtering memories: {e}")
        self.handle_error(e)
        raise
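
An illustrative call (the available filter_type values and their extra keys are defined by MemoryCleaner and are assumptions here):

criteria = {"filter_type": "filter_by_time", "start_time": "2024-01-01T00:00:00"}  # hypothetical filter
recent = palace.filter(criteria)   # raises ValueError if 'filter_type' is missing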
get(unit_id)

Retrieves a memory unit by its unique identifier.

Parameters:

Name Type Description Default
unit_id str

The unique identifier of the memory unit.

required

Returns:

Name Type Description
MemoryUnit MemoryUnit

The retrieved memory unit.

Source code in src/aeiva/cognition/memory/memory_palace.py
def get(self, unit_id: str) -> MemoryUnit:
    """
    Retrieves a memory unit by its unique identifier.

    Args:
        unit_id (str): The unique identifier of the memory unit.

    Returns:
        MemoryUnit: The retrieved memory unit.
    """
    try:
        memory_unit = self.storage.get_memory_unit(unit_id)
        logger.info(f"Retrieved MemoryUnit with ID: {unit_id}")
        return memory_unit
    except Exception as e:
        logger.error(f"Error retrieving MemoryUnit with ID {unit_id}: {e}")
        self.handle_error(e)
        raise
get_all()

Retrieves all memory units.

Returns:

Type Description
List[MemoryUnit]

List[MemoryUnit]: A list of all memory units.

Source code in src/aeiva/cognition/memory/memory_palace.py
def get_all(self) -> List[MemoryUnit]:
    """
    Retrieves all memory units.

    Returns:
        List[MemoryUnit]: A list of all memory units.
    """
    try:
        memory_units = self.storage.get_all_memory_units()
        logger.info(f"Retrieved all MemoryUnits. Total count: {len(memory_units)}")
        return memory_units
    except Exception as e:
        logger.error(f"Error retrieving all MemoryUnits: {e}")
        self.handle_error(e)
        raise
get_api_key(config_section, key_field, env_var_field) staticmethod

Retrieve an API key from the configuration section.

Parameters:

Name Type Description Default
config_section Dict[str, Any]

The configuration section (e.g., embedder_config).

required
key_field str

The key in the config_section that may contain the API key directly.

required
env_var_field str

The key in the config_section that specifies the environment variable name.

required

Returns:

Type Description
Optional[str]

Optional[str]: The API key if found, else None.

Raises:

Type Description
EnvironmentError

If the environment variable is specified but not set.

Source code in src/aeiva/cognition/memory/memory_palace.py
@staticmethod
def get_api_key(config_section: Dict[str, Any], key_field: str, env_var_field: str) -> Optional[str]:
    """
    Retrieve an API key from the configuration section.

    Args:
        config_section (Dict[str, Any]): The configuration section (e.g., embedder_config).
        key_field (str): The key in the config_section that may contain the API key directly.
        env_var_field (str): The key in the config_section that specifies the environment variable name.

    Returns:
        Optional[str]: The API key if found, else None.

    Raises:
        EnvironmentError: If the environment variable is specified but not set.
    """
    # Check if API key is provided directly
    api_key = config_section.get(key_field)
    if api_key:
        logger.info(f"Using provided API key for '{key_field}'.")
        return api_key

    # Else, check if an environment variable is specified
    env_var = config_section.get(env_var_field)
    if env_var:
        api_key = os.getenv(env_var)
        if api_key:
            logger.info(f"Retrieved API key for '{key_field}' from environment variable '{env_var}'.")
            return api_key
        else:
            logger.error(f"Environment variable '{env_var}' for '{key_field}' is not set.")
            raise EnvironmentError(f"Environment variable '{env_var}' for '{key_field}' is not set.")

    logger.warning(f"No API key provided for '{key_field}'.")
    return None
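
A sketch of both resolution paths, with illustrative field and variable names; with the corrected staticmethod signature no instance is needed:

import os

os.environ["MY_EMBEDDER_KEY"] = "sk-placeholder"              # hypothetical variable
section = {"api_key": None, "api_key_env_var": "MY_EMBEDDER_KEY"}
key = MemoryPalace.get_api_key(section, "api_key", "api_key_env_var")
# Returns the config value if present, else reads the named environment variable.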
handle_error(error)

Handle errors that occur during memory operations.

Parameters:

Name Type Description Default
error Exception

The exception that was raised.

required
Source code in src/aeiva/cognition/memory/memory_palace.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during memory operations.

    Args:
        error (Exception): The exception that was raised.
    """
    logger.error(f"MemoryPalace encountered an error: {error}")
load()

Loads all memory units from the storage.

Returns:

Type Description
List[MemoryUnit]

List[MemoryUnit]: A list of all loaded memory units.

Source code in src/aeiva/cognition/memory/memory_palace.py
def load(self) -> List[MemoryUnit]:
    """
    Loads all memory units from the storage.

    Returns:
        List[MemoryUnit]: A list of all loaded memory units.
    """
    try:
        # Retrieve all memory units from storage
        memory_units = self.get_all()
        logger.info(f"Loaded {len(memory_units)} MemoryUnits from storage.")
        return memory_units
    except Exception as e:
        logger.error(f"Error loading MemoryUnits: {e}")
        self.handle_error(e)
        raise
organize(unit_ids, organize_type, metadata=None)

Groups memory units into a meaningful group.

Parameters:

Name Type Description Default
unit_ids List[str]

A list of memory unit IDs to group.

required
organize_type str

The type of group (e.g., 'dialogue_session', 'procedure').

required
metadata Optional[Dict[str, Any]]

Additional metadata for the group.

None

Returns:

Name Type Description
str str

A unique identifier for the created group.

Source code in src/aeiva/cognition/memory/memory_palace.py
def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:
    """
    Groups memory units into a meaningful group.

    Args:
        unit_ids (List[str]): A list of memory unit IDs to group.
        organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').
        metadata (Optional[Dict[str, Any]]): Additional metadata for the group.

    Returns:
        str: A unique identifier for the created group.
    """
    try:
        # Retrieve the memory units to group
        memory_units = [self.get(unit_id) for unit_id in unit_ids]
        logger.debug(f"Grouping {len(memory_units)} MemoryUnits into group_type='{organize_type}'.")

        # Delegate grouping to MemoryOrganizer
        organized_memories = self.organizer.organize(memory_units, organize_type, metadata=metadata)
        logger.info(f"Grouped memories into '{organize_type}'. Total memory units after grouping: {len(organized_memories)}")
        return "group_id_placeholder"  # Replace with actual group ID if applicable
    except Exception as e:
        logger.error(f"Error grouping memories: {e}")
        self.handle_error(e)
        raise
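
For example (note that the current implementation returns a fixed placeholder rather than a persisted group ID):

group_id = palace.organize([unit.id], organize_type="dialogue_session")
# group_id == "group_id_placeholder" until a real group ID is wired in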
parameterize(**kwargs)

Trains a parametric model using the memory data.

Parameters:

Name Type Description Default
**kwargs

Additional parameters for the training process.

{}
Source code in src/aeiva/cognition/memory/memory_palace.py
def parameterize(self, **kwargs) -> None:
    """
    Trains a parametric model using the memory data.

    Args:
        **kwargs: Additional parameters for the training process.
    """
    try:
        # Retrieve all memory units
        memory_units = self.get_all()
        logger.debug(f"Parameterizing {len(memory_units)} MemoryUnits.")

        # Delegate parameterizing to MemoryParameterizer
        self.parameterizer.parameterize(memory_units, **kwargs)
        logger.info("Parameterized memories successfully.")
    except Exception as e:
        logger.error(f"Error parameterizing memories: {e}")
        self.handle_error(e)
        raise
retrieve(query, retrieve_type, **kwargs)

Retrieve data from memory based on a query.

Parameters:

Name Type Description Default
query Any

The query or criteria to retrieve specific memory data.

required
retrieve_type str

The type of retrieval (e.g., 'similar', 'related').

required
**kwargs

Additional parameters for the retrieval process.

{}

Returns:

Type Description
List[MemoryUnit]

List[MemoryUnit]: The retrieved memory data.

Source code in src/aeiva/cognition/memory/memory_palace.py
def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:
    """
    Retrieve data from memory based on a query.

    Args:
        query (Any): The query or criteria to retrieve specific memory data.
        retrieve_type (str): The type of retrieval (e.g., 'similar', 'related').
        **kwargs: Additional parameters for the retrieval process.

    Returns:
        List[MemoryUnit]: The retrieved memory data.
    """
    try:
        # Delegate retrieval to MemoryRetriever
        memories = self.retriever.retrieve(query=query, retrieve_type=retrieve_type, **kwargs)
        logger.info(f"Retrieved {len(memories)} memories using retrieve_type='{retrieve_type}'.")
        return memories
    except Exception as e:
        logger.error(f"Error retrieving MemoryUnits: {e}")
        self.handle_error(e)
        raise
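
Both retrieval types dispatch to MemoryRetriever; the relationship value below is illustrative:

similar = palace.retrieve("budget decisions", retrieve_type="similar", top_k=5)
related = palace.retrieve(unit.id, retrieve_type="related", relationship="follows")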
save(export_path=None)

Saves all memory units to the storage or exports them to a specified path.

Parameters:

Name Type Description Default
export_path Optional[str]

The file path to export memory units as JSON. If None, saves are handled by MemoryStorage.

None
Source code in src/aeiva/cognition/memory/memory_palace.py
def save(self, export_path: Optional[str] = None) -> None:
    """
    Saves all memory units to the storage or exports them to a specified path.

    Args:
        export_path (Optional[str]): The file path to export memory units as JSON.
                                    If None, saves are handled by MemoryStorage.
    """
    try:
        if export_path:
            # Export memory units to a JSON file
            memory_units = self.get_all()
            export_data = [mu.to_dict() for mu in memory_units]
            with open(export_path, 'w', encoding='utf-8') as f:
                json.dump(export_data, f, ensure_ascii=False, indent=4)
            logger.info(f"Exported {len(memory_units)} MemoryUnits to {export_path}.")
        else:
            # If no export path is provided, assume that MemoryStorage handles persistence
            logger.info("Save operation delegated to MemoryStorage.")
            # Example: self.storage.persist_changes()
    except Exception as e:
        logger.error(f"Error saving MemoryUnits: {e}")
        self.handle_error(e)
        raise
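
Two ways to call it:

palace.save("memories_backup.json")   # exports every unit to the given JSON file
palace.save()                         # no path: persistence is left to MemoryStorage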
setup()

Setup the MemoryPalace by initializing all components.

Source code in src/aeiva/cognition/memory/memory_palace.py
def setup(self):
    """
    Setup the MemoryPalace by initializing all components.
    """
    try:
        # Initialize EmbedderConfig
        embedder_config_dict = self.config_dict.get('embedder_config', {})
        self.embedder = Embedder(embedder_config_dict)

        storage_config_dict = self.config_dict.get('storage_config', {})
        self.storage = MemoryStorage(storage_config_dict) 

        # Initialize Memory Configuration
        self.config = MemoryConfig(
            embedder_config=self.embedder.config,
            storage_config=self.storage.config
        )

        logger.info("MemoryPalace: MemoryStorage and Embedder initialized successfully.")

        # Initialize specialized components
        self.cleaner = MemoryCleaner()
        self.organizer = MemoryOrganizer()
        self.retriever = MemoryRetriever(embedder=self.embedder, storage=self.storage)
        self.structurer = MemoryStructurer()
        self.skillizer = MemorySkillizer()
        self.parameterizer = MemoryParameterizer()
        logger.info("MemoryPalace: Specialized components initialized successfully.")

    except Exception as e:
        logger.error(f"MemoryPalace setup failed: {e}")
        self.handle_error(e)
        raise
skillize(unit_ids, skill_name, **kwargs)

Converts memory units into a reusable skill.

Parameters:

Name Type Description Default
unit_ids List[str]

A list of memory unit IDs to skillize.

required
skill_name str

The name of the skill to create.

required
**kwargs

Additional parameters for skill creation.

{}

Returns:

Name Type Description
str str

The unique identifier of the created skill.

Source code in src/aeiva/cognition/memory/memory_palace.py
def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:
    """
    Converts memory units into a reusable skill.

    Args:
        unit_ids (List[str]): A list of memory unit IDs to skillize.
        skill_name (str): The name of the skill to create.
        **kwargs: Additional parameters for skill creation.

    Returns:
        str: The unique identifier of the created skill.
    """
    try:
        # Retrieve the memory units to skillize
        memory_units = [self.get(uid) for uid in unit_ids]
        logger.debug(f"Skillizing {len(memory_units)} MemoryUnits into skill_name='{skill_name}'.")

        # Delegate skillizing to MemorySkillizer
        skill_id = self.skillizer.skillize(memory_units, skill_name, **kwargs)
        logger.info(f"Skillized memories into skill with ID: {skill_id}")
        return skill_id
    except Exception as e:
        logger.error(f"Error skillizing memories: {e}")
        self.handle_error(e)
        raise
structurize(unit_ids, structure_type, **kwargs)

Structures memory units into a knowledge graph or other structures.

Parameters:

Name Type Description Default
unit_ids List[str]

A list of memory unit IDs to structurize.

required
structure_type str

The type of structure (e.g., 'knowledge_graph').

required
**kwargs

Additional parameters for the structuring process.

{}
Source code in src/aeiva/cognition/memory/memory_palace.py
def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:
    """
    Structures memory units into a knowledge graph or other structures.

    Args:
        unit_ids (List[str]): A list of memory unit IDs to structurize.
        structure_type (str): The type of structure (e.g., 'knowledge_graph').
        **kwargs: Additional parameters for the structuring process.
    """
    try:
        # Retrieve the memory units to structurize
        memory_units = [self.get(uid) for uid in unit_ids]
        logger.debug(f"Structurizing {len(memory_units)} MemoryUnits with structure_type='{structure_type}'.")

        # Delegate structuring to MemoryStructurer
        self.structurer.structure(memory_units, structure_type, **kwargs)
        logger.info(f"Structurized memories with structure_type='{structure_type}'.")
    except Exception as e:
        logger.error(f"Error structurizing memories: {e}")
        self.handle_error(e)
        raise
update(unit_id, updates)

Updates a memory unit with the given updates.

Parameters:

Name Type Description Default
unit_id str

The unique identifier of the memory unit.

required
updates Dict[str, Any]

A dictionary of fields to update.

required
Source code in src/aeiva/cognition/memory/memory_palace.py
def update(self, unit_id: str, updates: Dict[str, Any]) -> None:
    """
    Updates a memory unit with the given updates.

    Args:
        unit_id (str): The unique identifier of the memory unit.
        updates (Dict[str, Any]): A dictionary of fields to update.
    """
    try:
        # Delegate update operations to MemoryStorage
        self.storage.update_memory_unit(unit_id, updates)
        logger.info(f"Updated MemoryUnit with ID: {unit_id}")
    except Exception as e:
        logger.error(f"Error updating MemoryUnit with ID {unit_id}: {e}")
        self.handle_error(e)
        raise

memory_parameterizer

MemoryParameterizer

A class to parameterize memory units based on various parameterizing algorithms.

Supported parameterize types:
  • 'parameterize_type_example': Placeholder for future parameterizing algorithms.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
class MemoryParameterizer:
    """
    A class to parameterize memory units based on various parameterizing algorithms.

    Supported parameterize types:
        - 'parameterize_type_example': Placeholder for future parameterizing algorithms.
    """

    def __init__(self):
        """
        Initializes the MemoryParameterizer.

        Currently, no initialization parameters are required.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug("Initialized MemoryParameterizer without default parameters.")

    def parameterize(
        self,
        memory_units: List[MemoryUnit],
        parameterize_type: str,
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Parameterizes the provided memory units based on the specified parameterize type.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be parameterized.
            parameterize_type (str): The type of parameterizing algorithm to use ('parameterize_type_example').
            **kwargs: Additional parameters required for specific parameterizers.

        Returns:
            List[MemoryUnit]: The list of memory units after parameterization.

        Raises:
            MemoryParameterizerError: If an unknown parameterize_type is provided or if parameterizing fails.
        """
        self.logger.debug(f"Parameterizing memory units using parameterize_type='{parameterize_type}' with kwargs={kwargs}")
        try:
            if parameterize_type == 'parameterize_type_example':
                # Placeholder for actual parameterizing logic
                return self.parameterize_example(memory_units, **kwargs)
            else:
                self.logger.error(f"Unknown parameterize_type: {parameterize_type}")
                raise MemoryParameterizerError(f"Unknown parameterize_type: {parameterize_type}")
        except MemoryParameterizerError:
            # Re-raise custom errors without modification
            raise
        except Exception as e:
            self.logger.error(f"Failed to parameterize memory units: {e}")
            raise MemoryParameterizerError(f"Failed to parameterize memory units: {e}")

    def parameterize_example(
        self,
        memory_units: List[MemoryUnit],
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Example parameterizing method. Currently a placeholder that returns memory units unchanged.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be parameterized.
            **kwargs: Additional parameters (currently unused).

        Returns:
            List[MemoryUnit]: The original list of memory units, unchanged.
        """
        self.logger.debug("Executing parameterize_example: No changes applied to memory units.")
        # Placeholder: No operation performed
        return memory_units
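
A usage sketch, assuming memory_units is a list of MemoryUnit; only the placeholder type is currently accepted:

from aeiva.cognition.memory.memory_parameterizer import MemoryParameterizer

parameterizer = MemoryParameterizer()
units = parameterizer.parameterize(memory_units, parameterize_type="parameterize_type_example")
# Returns the units unchanged; any other parameterize_type raises MemoryParameterizerError.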
__init__()

Initializes the MemoryParameterizer.

Currently, no initialization parameters are required.

Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def __init__(self):
    """
    Initializes the MemoryParameterizer.

    Currently, no initialization parameters are required.
    """
    self.logger = logging.getLogger(self.__class__.__name__)
    self.logger.debug("Initialized MemoryParameterizer without default parameters.")
parameterize(memory_units, parameterize_type, **kwargs)

Parameterizes the provided memory units based on the specified parameterize type.

Parameters:

Name Type Description Default
memory_units List[MemoryUnit]

The list of memory units to be parameterized.

required
parameterize_type str

The type of parameterizing algorithm to use ('parameterize_type_example').

required
**kwargs

Additional parameters required for specific parameterizers.

{}

Returns:

Type Description
List[MemoryUnit]

List[MemoryUnit]: The list of memory units after parameterization.

Raises:

Type Description
MemoryParameterizerError

If an unknown parameterize_type is provided or if parameterizing fails.

Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def parameterize(
    self,
    memory_units: List[MemoryUnit],
    parameterize_type: str,
    **kwargs
) -> List[MemoryUnit]:
    """
    Parameterizes the provided memory units based on the specified parameterize type.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be parameterized.
        parameterize_type (str): The type of parameterizing algorithm to use ('parameterize_type_example').
        **kwargs: Additional parameters required for specific parameterizers.

    Returns:
        List[MemoryUnit]: The list of memory units after parameterization.

    Raises:
        MemoryParameterizerError: If an unknown parameterize_type is provided or if parameterizing fails.
    """
    self.logger.debug(f"Parameterizing memory units using parameterize_type='{parameterize_type}' with kwargs={kwargs}")
    try:
        if parameterize_type == 'parameterize_type_example':
            # Placeholder for actual parameterizing logic
            return self.parameterize_example(memory_units, **kwargs)
        else:
            self.logger.error(f"Unknown parameterize_type: {parameterize_type}")
            raise MemoryParameterizerError(f"Unknown parameterize_type: {parameterize_type}")
    except MemoryParameterizerError:
        # Re-raise custom errors without modification
        raise
    except Exception as e:
        self.logger.error(f"Failed to parameterize memory units: {e}")
        raise MemoryParameterizerError(f"Failed to parameterize memory units: {e}")
parameterize_example(memory_units, **kwargs)

Example parameterizing method. Currently a placeholder that returns memory units unchanged.

Parameters:

Name Type Description Default
memory_units List[MemoryUnit]

The list of memory units to be parameterized.

required
**kwargs

Additional parameters (currently unused).

{}

Returns:

Type Description
List[MemoryUnit]

List[MemoryUnit]: The original list of memory units, unchanged.

Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def parameterize_example(
    self,
    memory_units: List[MemoryUnit],
    **kwargs
) -> List[MemoryUnit]:
    """
    Example parameterizing method. Currently a placeholder that returns memory units unchanged.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be parameterized.
        **kwargs: Additional parameters (currently unused).

    Returns:
        List[MemoryUnit]: The original list of memory units, unchanged.
    """
    self.logger.debug("Executing parameterize_example: No changes applied to memory units.")
    # Placeholder: No operation performed
    return memory_units
MemoryParameterizerError

Bases: Exception

Exception raised when an error occurs in the MemoryParameterizer.

Source code in src/aeiva/cognition/memory/memory_parameterizer.py
class MemoryParameterizerError(Exception):
    """Exception raised when an error occurs in the MemoryParameterizer."""
    pass

memory_retriever

MemoryRetriever

A class to retrieve memory units based on various retrieval algorithms.

Supported retrieval types:
  • 'similar': Retrieves memory units similar to a given query based on embeddings.
  • 'related': Retrieves memory units related to a specified query based on relationships.
Source code in src/aeiva/cognition/memory/memory_retriever.py
class MemoryRetriever:
    """
    A class to retrieve memory units based on various retrieval algorithms.

    Supported retrieval types:
        - 'similar': Retrieves memory units similar to a given query based on embeddings.
        - 'related': Retrieves memory units related to a specified query based on relationships.
    """

    def __init__(self, embedder: Embedder, storage: MemoryStorage):
        """
        Initializes the MemoryRetriever.

        Args:
            embedder (Embedder): An instance responsible for generating embeddings.
            storage (MemoryStorage): An instance managing data storage and retrieval.
        """
        self.embedder = embedder
        self.storage = storage
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug("Initialized MemoryRetriever with provided embedder and storage.")

    def retrieve(
        self,
        query: Any,
        retrieve_type: str,
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Factory method to retrieve memory units based on the specified retrieval type.

        Args:
            query (Any): The query for retrieval.
            retrieve_type (str): The type of retrieval ('similar' or 'related').
            **kwargs: Additional parameters required for specific retrieval types.
                For 'similar' retrieval:
                    - top_k (int): The number of similar units to retrieve.
                For 'related' retrieval:
                    - relationship (Optional[str]): The type of relationship to filter by.

        Returns:
            List[MemoryUnit]: A list of retrieved memory units.

        Raises:
            MemoryRetrieverError: If an unknown retrieval_type is provided or if retrieval fails.
        """
        self.logger.info(f"Initiating retrieval of type '{retrieve_type}' with query: {query}")
        try:
            if retrieve_type == 'similar':
                top_k = kwargs.get('top_k', 5)
                self.logger.debug(f"Retrieval Type: 'similar' with top_k={top_k}")
                return self.retrieve_similar(query, top_k)
            elif retrieve_type == 'related':
                relationship = kwargs.get('relationship')
                self.logger.debug(f"Retrieval Type: 'related' with relationship='{relationship}'")
                return self.retrieve_related(query, relationship)
            else:
                self.logger.error(f"Unknown retrieve_type: {retrieve_type}")
                raise MemoryRetrieverError(f"Unknown retrieve_type: {retrieve_type}")
        except MemoryRetrieverError:
            # Re-raise custom errors without modification
            raise
        except Exception as e:
            self.logger.error(f"Failed to retrieve memory units: {e}")
            raise MemoryRetrieverError(f"Failed to retrieve memory units: {e}") from e

    def retrieve_similar(self, query: Any, top_k: int = 5) -> List[MemoryUnit]:
        """
        Retrieves memory units similar to the given input based on embeddings.

        Args:
            query (Any): The query for retrieval.
            top_k (int): The number of similar units to retrieve.

        Returns:
            List[MemoryUnit]: A list of similar memory units.

        Raises:
            MemoryRetrieverError: If retrieval fails due to embedding generation or storage issues.
        """
        self.logger.info(f"Retrieving top {top_k} similar MemoryUnits based on the query.")
        try:
            # Generate embedding for the query
            self.logger.debug("Generating embedding for the query.")
            embedding_response = self.embedder.embed(query)
            if not embedding_response.get("data"):
                self.logger.error("Failed to generate embedding for the query.")
                raise MemoryRetrieverError("Failed to generate embedding for the query.")

            query_embedding = embedding_response["data"][0].get("embedding")
            if not query_embedding:
                self.logger.error("Embedding data is missing in the response.")
                raise MemoryRetrieverError("Embedding data is missing in the response.")

            self.logger.debug(f"Embedding generated successfully: {query_embedding}")

            # Perform similarity search via MemoryStorage
            self.logger.debug("Performing similarity search in the vector database.")
            similar_units = self.storage.retrieve_similar_memory_units(query_embedding, top_k)
            self.logger.info(f"Retrieved {len(similar_units)} similar MemoryUnits.")
            return similar_units

        except MemoryRetrieverError:
            # Re-raise custom errors without modification
            raise
        except Exception as e:
            self.logger.error(f"Unexpected error during retrieve_similar: {e}")
            raise MemoryRetrieverError(f"Unexpected error during retrieve_similar: {e}") from e

    def retrieve_related(
        self,
        query: Any,
        relationship: Optional[str] = None
    ) -> List[MemoryUnit]:  # TODO: revise the method later
        """
        Retrieves memory units related to the given query based on relationships.

        Args:
            query (Any): The query for retrieval. Expected to be a MemoryUnit ID or similar identifier.
            relationship (Optional[str]): The type of relationship to filter by.

        Returns:
            List[MemoryUnit]: A list of related memory units.

        Raises:
            MemoryRetrieverError: If retrieval fails due to storage issues or invalid queries.
        """
        self.logger.info(f"Retrieving memories related to the query with relationship: {relationship}")
        try:
            # Assuming 'query' is a MemoryUnit ID or can be used to fetch a MemoryUnit
            self.logger.debug("Fetching the target MemoryUnit from storage.")
            target_memory_unit = self.storage.get_memory_unit(query)
            if not target_memory_unit:
                self.logger.error(f"MemoryUnit with ID '{query}' not found.")
                raise MemoryRetrieverError(f"MemoryUnit with ID '{query}' not found.")

            self.logger.debug(f"MemoryUnit fetched successfully: {target_memory_unit}")

            # Perform related retrieval via MemoryStorage
            self.logger.debug("Retrieving related MemoryUnits from the graph database.")
            related_units = self.storage.retrieve_related_memory_units(target_memory_unit.id, relationship)
            self.logger.info(f"Retrieved {len(related_units)} related MemoryUnits.")
            return related_units

        except MemoryRetrieverError:
            # Re-raise custom errors without modification
            raise
        except Exception as e:
            self.logger.error(f"Unexpected error during retrieve_related: {e}")
            raise MemoryRetrieverError(f"Unexpected error during retrieve_related: {e}") from e

    def handle_error(self, error: Exception):
        """
        Handles errors by logging or performing other necessary actions.

        Args:
            error (Exception): The exception to handle.
        """
        # Implement any error handling logic here
        # For now, we'll just log the error
        self.logger.error(f"An error occurred: {error}")
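
A usage sketch, assuming an Embedder and a MemoryStorage instance have already been constructed (for example by MemoryPalace.setup()):

retriever = MemoryRetriever(embedder=embedder, storage=storage)
nearest = retriever.retrieve("what did we decide on pricing?", retrieve_type="similar", top_k=3)
linked = retriever.retrieve(unit_id, retrieve_type="related", relationship=None)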
__init__(embedder, storage)

Initializes the MemoryRetriever.

Parameters:

Name Type Description Default
embedder Embedder

An instance responsible for generating embeddings.

required
storage MemoryStorage

An instance managing data storage and retrieval.

required
Source code in src/aeiva/cognition/memory/memory_retriever.py
def __init__(self, embedder: Embedder, storage: MemoryStorage):
    """
    Initializes the MemoryRetriever.

    Args:
        embedder (Embedder): An instance responsible for generating embeddings.
        storage (MemoryStorage): An instance managing data storage and retrieval.
    """
    self.embedder = embedder
    self.storage = storage
    self.logger = logging.getLogger(self.__class__.__name__)
    self.logger.debug("Initialized MemoryRetriever with provided embedder and storage.")
handle_error(error)

Handles errors by logging or performing other necessary actions.

Parameters:

Name Type Description Default
error Exception

The exception to handle.

required
Source code in src/aeiva/cognition/memory/memory_retriever.py
def handle_error(self, error: Exception):
    """
    Handles errors by logging or performing other necessary actions.

    Args:
        error (Exception): The exception to handle.
    """
    # Implement any error handling logic here
    # For now, we'll just log the error
    self.logger.error(f"An error occurred: {error}")
retrieve(query, retrieve_type, **kwargs)

Factory method to retrieve memory units based on the specified retrieval type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | Any | The query for retrieval. | required |
| retrieve_type | str | The type of retrieval ('similar' or 'related'). | required |
| **kwargs |  | Additional parameters required for specific retrieval types. For 'similar' retrieval: top_k (int), the number of similar units to retrieve. For 'related' retrieval: relationship (Optional[str]), the type of relationship to filter by. | {} |

Returns:

| Type | Description |
| --- | --- |
| List[MemoryUnit] | A list of retrieved memory units. |

Raises:

| Type | Description |
| --- | --- |
| MemoryRetrieverError | If an unknown retrieve_type is provided or if retrieval fails. |

Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve(
    self,
    query: Any,
    retrieve_type: str,
    **kwargs
) -> List[MemoryUnit]:
    """
    Factory method to retrieve memory units based on the specified retrieval type.

    Args:
        query (Any): The query for retrieval.
        retrieve_type (str): The type of retrieval ('similar' or 'related').
        **kwargs: Additional parameters required for specific retrieval types.
            For 'similar' retrieval:
                - top_k (int): The number of similar units to retrieve.
            For 'related' retrieval:
                - relationship (Optional[str]): The type of relationship to filter by.

    Returns:
        List[MemoryUnit]: A list of retrieved memory units.

    Raises:
        MemoryRetrieverError: If an unknown retrieve_type is provided or if retrieval fails.
    """
    self.logger.info(f"Initiating retrieval of type '{retrieve_type}' with query: {query}")
    try:
        if retrieve_type == 'similar':
            top_k = kwargs.get('top_k', 5)
            self.logger.debug(f"Retrieval Type: 'similar' with top_k={top_k}")
            return self.retrieve_similar(query, top_k)
        elif retrieve_type == 'related':
            relationship = kwargs.get('relationship')
            self.logger.debug(f"Retrieval Type: 'related' with relationship='{relationship}'")
            return self.retrieve_related(query, relationship)
        else:
            self.logger.error(f"Unknown retrieve_type: {retrieve_type}")
            raise MemoryRetrieverError(f"Unknown retrieve_type: {retrieve_type}")
    except MemoryRetrieverError:
        # Re-raise custom errors without modification
        raise
    except Exception as e:
        self.logger.error(f"Failed to retrieve memory units: {e}")
        raise MemoryRetrieverError(f"Failed to retrieve memory units: {e}") from e

retrieve_related(query, relationship=None)

Retrieves memory units related to the given query based on relationships.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | Any | The query for retrieval. Expected to be a MemoryUnit ID or similar identifier. | required |
| relationship | Optional[str] | The type of relationship to filter by. | None |

Returns:

| Type | Description |
| --- | --- |
| List[MemoryUnit] | A list of related memory units. |

Raises:

| Type | Description |
| --- | --- |
| MemoryRetrieverError | If retrieval fails due to storage issues or invalid queries. |

Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve_related(
    self,
    query: Any,
    relationship: Optional[str] = None
) -> List[MemoryUnit]:  # TODO: revise the method later
    """
    Retrieves memory units related to the given query based on relationships.

    Args:
        query (Any): The query for retrieval. Expected to be a MemoryUnit ID or similar identifier.
        relationship (Optional[str]): The type of relationship to filter by.

    Returns:
        List[MemoryUnit]: A list of related memory units.

    Raises:
        MemoryRetrieverError: If retrieval fails due to storage issues or invalid queries.
    """
    self.logger.info(f"Retrieving memories related to the query with relationship: {relationship}")
    try:
        # Assuming 'query' is a MemoryUnit ID or can be used to fetch a MemoryUnit
        self.logger.debug("Fetching the target MemoryUnit from storage.")
        target_memory_unit = self.storage.get_memory_unit(query)
        if not target_memory_unit:
            self.logger.error(f"MemoryUnit with ID '{query}' not found.")
            raise MemoryRetrieverError(f"MemoryUnit with ID '{query}' not found.")

        self.logger.debug(f"MemoryUnit fetched successfully: {target_memory_unit}")

        # Perform related retrieval via MemoryStorage
        self.logger.debug("Retrieving related MemoryUnits from the graph database.")
        related_units = self.storage.retrieve_related_memory_units(target_memory_unit.id, relationship)
        self.logger.info(f"Retrieved {len(related_units)} related MemoryUnits.")
        return related_units

    except MemoryRetrieverError:
        # Re-raise custom errors without modification
        raise
    except Exception as e:
        self.logger.error(f"Unexpected error during retrieve_related: {e}")
        raise MemoryRetrieverError(f"Unexpected error during retrieve_related: {e}") from e
retrieve_similar(query, top_k=5)

Retrieves memory units similar to the given input based on embeddings.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | Any | The query for retrieval. | required |
| top_k | int | The number of similar units to retrieve. | 5 |

Returns:

| Type | Description |
| --- | --- |
| List[MemoryUnit] | A list of similar memory units. |

Raises:

| Type | Description |
| --- | --- |
| MemoryRetrieverError | If retrieval fails due to embedding generation or storage issues. |

Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve_similar(self, query: Any, top_k: int = 5) -> List[MemoryUnit]:
    """
    Retrieves memory units similar to the given input based on embeddings.

    Args:
        query (Any): The query for retrieval.
        top_k (int): The number of similar units to retrieve.

    Returns:
        List[MemoryUnit]: A list of similar memory units.

    Raises:
        MemoryRetrieverError: If retrieval fails due to embedding generation or storage issues.
    """
    self.logger.info(f"Retrieving top {top_k} similar MemoryUnits based on the query.")
    try:
        # Generate embedding for the query
        self.logger.debug("Generating embedding for the query.")
        embedding_response = self.embedder.embed(query)
        if not embedding_response.get("data"):
            self.logger.error("Failed to generate embedding for the query.")
            raise MemoryRetrieverError("Failed to generate embedding for the query.")

        query_embedding = embedding_response["data"][0].get("embedding")
        if not query_embedding:
            self.logger.error("Embedding data is missing in the response.")
            raise MemoryRetrieverError("Embedding data is missing in the response.")

        self.logger.debug(f"Embedding generated successfully: {query_embedding}")

        # Perform similarity search via MemoryStorage
        self.logger.debug("Performing similarity search in the vector database.")
        similar_units = self.storage.retrieve_similar_memory_units(query_embedding, top_k)
        self.logger.info(f"Retrieved {len(similar_units)} similar MemoryUnits.")
        return similar_units

    except MemoryRetrieverError:
        # Re-raise custom errors without modification
        raise
    except Exception as e:
        self.logger.error(f"Unexpected error during retrieve_similar: {e}")
        raise MemoryRetrieverError(f"Unexpected error during retrieve_similar: {e}") from e
MemoryRetrieverError

Bases: Exception

Exception raised when an error occurs in the MemoryRetriever.

Source code in src/aeiva/cognition/memory/memory_retriever.py
class MemoryRetrieverError(Exception):
    """Exception raised when an error occurs in the MemoryRetriever."""
    pass

memory_skillizer

MemorySkillizer

A class to skillize memory units based on various skillizing algorithms.

Supported skill types:
  • 'skill_type_example': Placeholder for future skillizing algorithms.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
class MemorySkillizer:
    """
    A class to skillize memory units based on various skillizing algorithms.

    Supported skill types:
        - 'skill_type_example': Placeholder for future skillizing algorithms.
    """

    def __init__(self):
        """
        Initializes the MemorySkillizer.

        Currently, no initialization parameters are required.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug("Initialized MemorySkillizer without default parameters.")

    def skillize(
        self,
        memory_units: List[MemoryUnit],
        skill_type: str,
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Skillizes the provided memory units based on the specified skill type.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be skillized.
            skill_type (str): The type of skillizing algorithm to use ('skill_type_example').
            **kwargs: Additional parameters required for specific skillizers.

        Returns:
            List[MemoryUnit]: The list of memory units after skillizing.

        Raises:
            MemorySkillizerError: If an unknown skill_type is provided or if skillizing fails.
        """
        self.logger.debug(f"Skillizing memory units using skill_type='{skill_type}' with kwargs={kwargs}")
        try:
            if skill_type == 'skill_type_example':
                # Placeholder for actual skillizing logic
                return self.skillize_example(memory_units, **kwargs)
            else:
                self.logger.error(f"Unknown skill_type: {skill_type}")
                raise MemorySkillizerError(f"Unknown skill_type: {skill_type}")
        except MemorySkillizerError:
            # Re-raise custom errors without modification
            raise
        except Exception as e:
            self.logger.error(f"Failed to skillize memory units: {e}")
            raise MemorySkillizerError(f"Failed to skillize memory units: {e}")

    def skillize_example(
        self,
        memory_units: List[MemoryUnit],
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Example skillizing method. Currently a placeholder that returns memory units unchanged.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be skillized.
            **kwargs: Additional parameters (currently unused).

        Returns:
            List[MemoryUnit]: The original list of memory units, unchanged.
        """
        self.logger.debug("Executing skillize_example: No changes applied to memory units.")
        # Placeholder: No operation performed
        return memory_units
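
Because skillize dispatches on skill_type, supporting a new algorithm means adding one branch plus a handler method. A hypothetical sketch of such an extension; the 'deduplicate' skill type and its logic are not part of the current codebase:

# Hypothetical branch inside MemorySkillizer.skillize:
#     elif skill_type == 'deduplicate':
#         return self.skillize_deduplicate(memory_units, **kwargs)

def skillize_deduplicate(self, memory_units, **kwargs):
    """Drops memory units with duplicate IDs, keeping the first occurrence."""
    seen = set()
    unique = []
    for unit in memory_units:
        if unit.id not in seen:
            seen.add(unit.id)
            unique.append(unit)
    return unique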
__init__()

Initializes the MemorySkillizer.

Currently, no initialization parameters are required.

Source code in src/aeiva/cognition/memory/memory_skillizer.py
def __init__(self):
    """
    Initializes the MemorySkillizer.

    Currently, no initialization parameters are required.
    """
    self.logger = logging.getLogger(self.__class__.__name__)
    self.logger.debug("Initialized MemorySkillizer without default parameters.")
skillize(memory_units, skill_type, **kwargs)

Skillizes the provided memory units based on the specified skill type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| memory_units | List[MemoryUnit] | The list of memory units to be skillized. | required |
| skill_type | str | The type of skillizing algorithm to use ('skill_type_example'). | required |
| **kwargs |  | Additional parameters required for specific skillizers. | {} |

Returns:

| Type | Description |
| --- | --- |
| List[MemoryUnit] | The list of memory units after skillizing. |

Raises:

| Type | Description |
| --- | --- |
| MemorySkillizerError | If an unknown skill_type is provided or if skillizing fails. |

Source code in src/aeiva/cognition/memory/memory_skillizer.py
def skillize(
    self,
    memory_units: List[MemoryUnit],
    skill_type: str,
    **kwargs
) -> List[MemoryUnit]:
    """
    Skillizes the provided memory units based on the specified skill type.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be skillized.
        skill_type (str): The type of skillizing algorithm to use ('skill_type_example').
        **kwargs: Additional parameters required for specific skillizers.

    Returns:
        List[MemoryUnit]: The list of memory units after skillizing.

    Raises:
        MemorySkillizerError: If an unknown skill_type is provided or if skillizing fails.
    """
    self.logger.debug(f"Skillizing memory units using skill_type='{skill_type}' with kwargs={kwargs}")
    try:
        if skill_type == 'skill_type_example':
            # Placeholder for actual skillizing logic
            return self.skillize_example(memory_units, **kwargs)
        else:
            self.logger.error(f"Unknown skill_type: {skill_type}")
            raise MemorySkillizerError(f"Unknown skill_type: {skill_type}")
    except MemorySkillizerError:
        # Re-raise custom errors without modification
        raise
    except Exception as e:
        self.logger.error(f"Failed to skillize memory units: {e}")
        raise MemorySkillizerError(f"Failed to skillize memory units: {e}")
skillize_example(memory_units, **kwargs)

Example skillizing method. Currently a placeholder that returns memory units unchanged.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| memory_units | List[MemoryUnit] | The list of memory units to be skillized. | required |
| **kwargs |  | Additional parameters (currently unused). | {} |

Returns:

| Type | Description |
| --- | --- |
| List[MemoryUnit] | The original list of memory units, unchanged. |

Source code in src/aeiva/cognition/memory/memory_skillizer.py
def skillize_example(
    self,
    memory_units: List[MemoryUnit],
    **kwargs
) -> List[MemoryUnit]:
    """
    Example skillizing method. Currently a placeholder that returns memory units unchanged.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be skillized.
        **kwargs: Additional parameters (currently unused).

    Returns:
        List[MemoryUnit]: The original list of memory units, unchanged.
    """
    self.logger.debug("Executing skillize_example: No changes applied to memory units.")
    # Placeholder: No operation performed
    return memory_units
MemorySkillizerError

Bases: Exception

Exception raised when an error occurs in the MemorySkillizer.

Source code in src/aeiva/cognition/memory/memory_skillizer.py
class MemorySkillizerError(Exception):
    """Exception raised when an error occurs in the MemorySkillizer."""
    pass

memory_storage

MemoryEventRepository

Repository for MemoryEvent to handle CRUD operations without SQLAlchemy.

Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryEventRepository:
    """
    Repository for MemoryEvent to handle CRUD operations without SQLAlchemy.
    """

    def __init__(self, db: Any):
        """
        Initialize the repository with a DatabaseFactory instance.

        Args:
            db (Any): An instance of DatabaseFactory for relational databases.
        """
        self.db = db
        self.table_name = 'memory_events'
        self._create_table()

    def _create_table(self):
        """
        Creates the memory_events table if it does not exist.
        """
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {self.table_name} (
            id TEXT PRIMARY KEY,
            memory_id TEXT NOT NULL,
            event_type TEXT NOT NULL,
            timestamp TEXT NOT NULL,
            memory_data TEXT,
            previous_data TEXT
        );
        """
        self.db.execute_sql(create_table_query)

    def add(self, event: Dict[str, Any]) -> None:
        """
        Adds a MemoryEvent to the relational database.

        Args:
            event (Dict[str, Any]): The event data to add.
        """
        insert_query = f"""
        INSERT INTO {self.table_name} (id, memory_id, event_type, timestamp, memory_data, previous_data)
        VALUES (?, ?, ?, ?, ?, ?);
        """
        data = (
            event.get('id', uuid4().hex),
            event['memory_id'],
            event['event_type'],
            datetime.utcnow().isoformat(),  # TODO: revise utcnow.
            event.get('memory_data'),
            event.get('previous_data')
        )
        self.db.execute_sql(insert_query, data)

    def get(self, event_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves a MemoryEvent by its ID.

        Args:
            event_id (str): The unique identifier of the event.

        Returns:
            Optional[Dict[str, Any]]: The event data or None if not found.
        """
        select_query = f"SELECT * FROM {self.table_name} WHERE id = ?;"
        result = self.db.execute_sql(select_query, (event_id,))
        row = result.fetchone()
        if row:
            return self._row_to_event(row)
        return None

    def delete_all(self) -> None:
        """
        Deletes all MemoryEvents from the relational database.
        """
        delete_query = f"DELETE FROM {self.table_name};"
        self.db.execute_sql(delete_query)

    def list_all(self) -> List[Dict[str, Any]]:
        """
        Retrieves all MemoryEvents from the relational database.

        Returns:
            List[Dict[str, Any]]: A list of all events.
        """
        select_query = f"SELECT * FROM {self.table_name};"
        results = self.db.execute_sql(select_query)
        return [self._row_to_event(row) for row in results.fetchall()]

    def _row_to_event(self, row: Any) -> Dict[str, Any]:
        """
        Converts a database row to an event dictionary.

        Args:
            row (Any): A row fetched from the database.

        Returns:
            Dict[str, Any]: The corresponding event data.
        """
        return {
            "id": row['id'],
            "memory_id": row['memory_id'],
            "event_type": row['event_type'],
            "timestamp": datetime.fromisoformat(row['timestamp']),
            "memory_data": json.loads(row['memory_data']) if row['memory_data'] else None,
            "previous_data": json.loads(row['previous_data']) if row['previous_data'] else None
        }
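
The repository stores memory_data and previous_data as TEXT and runs json.loads on them when reading, so callers are expected to pass JSON strings (as _record_event in MemoryStorage does). A minimal usage sketch, assuming `db` is an already-configured relational database wrapper from DatabaseFactory; the memory ID and payload are illustrative:

import json
from uuid import uuid4

repo = MemoryEventRepository(db)  # `db` assumed configured elsewhere

event_id = uuid4().hex
repo.add({
    "id": event_id,
    "memory_id": "unit-123",      # illustrative MemoryUnit ID
    "event_type": "CREATE",
    "memory_data": json.dumps({"content": "hello"}),  # stored as a JSON string
    "previous_data": None,
})

event = repo.get(event_id)        # memory_data is returned deserialized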
__init__(db)

Initialize the repository with a DatabaseFactory instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| db | Any | An instance of DatabaseFactory for relational databases. | required |
Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, db: Any):
    """
    Initialize the repository with a DatabaseFactory instance.

    Args:
        db (Any): An instance of DatabaseFactory for relational databases.
    """
    self.db = db
    self.table_name = 'memory_events'
    self._create_table()
add(event)

Adds a MemoryEvent to the relational database.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| event | Dict[str, Any] | The event data to add. | required |
Source code in src/aeiva/cognition/memory/memory_storage.py
def add(self, event: Dict[str, Any]) -> None:
    """
    Adds a MemoryEvent to the relational database.

    Args:
        event (Dict[str, Any]): The event data to add.
    """
    insert_query = f"""
    INSERT INTO {self.table_name} (id, memory_id, event_type, timestamp, memory_data, previous_data)
    VALUES (?, ?, ?, ?, ?, ?);
    """
    data = (
        event.get('id', uuid4().hex),
        event['memory_id'],
        event['event_type'],
        datetime.utcnow().isoformat(),  # TODO: revise utcnow.
        event.get('memory_data'),
        event.get('previous_data')
    )
    self.db.execute_sql(insert_query, data)
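
The insert path timestamps each event with datetime.utcnow() at write time, which the TODO in the source flags for revision: utcnow() returns a naive datetime and is deprecated in recent Python versions. The timezone-aware form below is the usual replacement; whether the project adopts it is still an open TODO:

from datetime import datetime, timezone

# Timezone-aware equivalent of datetime.utcnow().isoformat():
timestamp = datetime.now(timezone.utc).isoformat()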
delete_all()

Deletes all MemoryEvents from the relational database.

Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all(self) -> None:
    """
    Deletes all MemoryEvents from the relational database.
    """
    delete_query = f"DELETE FROM {self.table_name};"
    self.db.execute_sql(delete_query)
get(event_id)

Retrieves a MemoryEvent by its ID.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| event_id | str | The unique identifier of the event. | required |

Returns:

| Type | Description |
| --- | --- |
| Optional[Dict[str, Any]] | The event data or None if not found. |

Source code in src/aeiva/cognition/memory/memory_storage.py
def get(self, event_id: str) -> Optional[Dict[str, Any]]:
    """
    Retrieves a MemoryEvent by its ID.

    Args:
        event_id (str): The unique identifier of the event.

    Returns:
        Optional[Dict[str, Any]]: The event data or None if not found.
    """
    select_query = f"SELECT * FROM {self.table_name} WHERE id = ?;"
    result = self.db.execute_sql(select_query, (event_id,))
    row = result.fetchone()
    if row:
        return self._row_to_event(row)
    return None
list_all()

Retrieves all MemoryEvents from the relational database.

Returns:

| Type | Description |
| --- | --- |
| List[Dict[str, Any]] | A list of all events. |

Source code in src/aeiva/cognition/memory/memory_storage.py
def list_all(self) -> List[Dict[str, Any]]:
    """
    Retrieves all MemoryEvents from the relational database.

    Returns:
        List[Dict[str, Any]]: A list of all events.
    """
    select_query = f"SELECT * FROM {self.table_name};"
    results = self.db.execute_sql(select_query)
    return [self._row_to_event(row) for row in results.fetchall()]
MemoryStorage

Handles storage operations for MemoryPalace, including interactions with vector, graph, and relational databases.

Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryStorage:
    """
    Handles storage operations for MemoryPalace, including interactions with vector,
    graph, and relational databases.
    """

    def __init__(self, config: Dict):
        """
        Initialize the MemoryStorage with the provided configuration.

        Args:
            config (Any): Configuration settings for MemoryStorage.
        """
        self.config_dict = config
        self.config = None
        self.setup()

    def setup(self) -> None:
        """
        Set up the MemoryStorage's components based on the provided configuration.
        """
        try:
            # Initialize Vector Database Configuration
            vector_db_conf_dict = self.config_dict.get('vector_db_config', {})
            vector_db_provider_name = vector_db_conf_dict.get('provider_name', 'milvus')
            vector_db_config = DatabaseConfigFactory.create(
                provider_name=vector_db_conf_dict.get('provider_name', 'milvus'),
                uri=vector_db_conf_dict.get('uri', 'storage/milvus_demo.db'),
                collection_name=vector_db_conf_dict.get('collection_name', 'test_collection'),
                embedding_model_dims=vector_db_conf_dict.get('embedding_model_dims', 1536),  # 'text-embedding-ada-002': 1536,
                metric_type=vector_db_conf_dict.get('metric_type', 'COSINE')
            )

            # Initialize Graph Database Configuration
            graph_db_conf_dict = self.config_dict.get('graph_db_config', {})
            graph_db_provider_name = graph_db_conf_dict.get('provider_name', 'neo4j')
            graph_db_password = graph_db_conf_dict.get('password')
            graph_db_config = DatabaseConfigFactory.create(
                provider_name=graph_db_conf_dict.get('provider_name', 'neo4j'),
                uri=graph_db_conf_dict.get('uri', 'bolt://localhost:7687'),
                user=graph_db_conf_dict.get('user', 'neo4j'),
                password=graph_db_password,
                database=graph_db_conf_dict.get('database', 'neo4j'),
                encrypted=graph_db_conf_dict.get('encrypted', False)
            )

            # Initialize Relational Database Configuration
            relational_db_conf_dict = self.config_dict.get('relational_db_config', {})
            relational_db_provider_name = relational_db_conf_dict.get('provider_name', 'sqlite')
            relational_db_config = DatabaseConfigFactory.create(
                provider_name=relational_db_conf_dict.get('provider_name', 'sqlite'),
                database=relational_db_conf_dict.get('database', 'storage/test_database.db')
            )

            self.config = StorageConfig(
                vector_db_provider=self.config_dict.get('vector_db_provider', 'milvus'),
                vector_db_config=vector_db_config,
                graph_db_provider=self.config_dict.get('graph_db_provider', 'neo4j'),
                graph_db_config=graph_db_config,
                relational_db_provider=self.config_dict.get('relational_db_provider', 'sqlite'),
                relational_db_config=relational_db_config,
            )

            # Initialize the vector database
            self.vector_db = DatabaseFactory.create(
                provider_name=vector_db_provider_name,
                config=self.config.vector_db_config
            )

            # Initialize the graph database if provided
            if graph_db_provider_name and self.config.graph_db_config:
                self.graph_db = DatabaseFactory.create(
                    provider_name=graph_db_provider_name,
                    config=self.config.graph_db_config
                )
            else:
                self.graph_db = None

            # Initialize the relational database if provided
            if relational_db_provider_name and self.config.relational_db_config:
                self.relational_db = DatabaseFactory.create(
                    provider_name=relational_db_provider_name,
                    config=self.config.relational_db_config
                )
                self.memory_unit_repo = MemoryUnitRepository(self.relational_db)
                self.memory_event_repo = MemoryEventRepository(self.relational_db)
            else:
                self.relational_db = None
                self.memory_unit_repo = None
                self.memory_event_repo = None

            logger.info("MemoryStorage setup completed successfully.")
        except Exception as e:
            logger.error(f"Error during MemoryStorage setup: {e}")
            self.handle_error(e)
            raise  # Re-raise the exception after logging

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during storage operations.

        Args:
            error (Exception): The exception that was raised.
        """
        logger.error(f"MemoryStorage encountered an error: {error}")
        # Additional error handling can be implemented here

    def add_memory_unit(self, memory_unit: MemoryUnit) -> None:
        """
        Adds a MemoryUnit to all configured databases.

        Args:
            memory_unit (MemoryUnit): The memory unit to add.
        """
        try:
            # Add to vector database
            self._add_to_vector_db(memory_unit)

            # Add to graph database
            if self.graph_db:
                self._add_to_graph_db(memory_unit)

            # Add to relational database
            if self.relational_db and self.memory_unit_repo:
                self._add_to_relational_db(memory_unit)

            # Record creation event
            if self.relational_db and self.memory_event_repo:
                self._record_event(
                    event_type="CREATE",
                    memory_unit=memory_unit
                )

            logger.info(f"Added MemoryUnit with ID: {memory_unit.id} to all databases.")
        except Exception as e:
            logger.error(f"Error adding MemoryUnit to databases: {e}")
            self.handle_error(e)
            raise

    def get_memory_unit(self, unit_id: str) -> MemoryUnit:
        """
        Retrieves a MemoryUnit by its unique identifier from the relational database.

        Args:
            unit_id (str): The unique identifier of the memory unit.

        Returns:
            MemoryUnit: The retrieved memory unit.
        """
        try:
            if not self.relational_db or not self.memory_unit_repo:
                raise ValueError("Relational database is not configured.")

            memory_unit = self.memory_unit_repo.get(unit_id)
            if not memory_unit:
                raise ValueError(f"MemoryUnit with ID {unit_id} does not exist.")

            logger.info(f"Retrieved MemoryUnit with ID: {unit_id} from Relational DB.")
            return memory_unit
        except Exception as e:
            logger.error(f"Error retrieving MemoryUnit with ID {unit_id}: {e}")
            self.handle_error(e)
            raise

    def update_memory_unit(self, unit_id: str, updates: Dict[str, Any]) -> None:
        """
        Updates a MemoryUnit in all configured databases.

        Args:
            unit_id (str): The unique identifier of the memory unit.
            updates (Dict[str, Any]): The updates to apply.
        """
        try:
            # Retrieve existing MemoryUnit
            memory_unit = self.get_memory_unit(unit_id)
            previous_state = memory_unit.to_dict()

            # Apply updates
            for key, value in updates.items():
                setattr(memory_unit, key, value)

            # Update in vector database
            self._update_vector_db(memory_unit)

            # Update in graph database
            if self.graph_db:
                self._update_graph_db(memory_unit)

            # Update in relational database
            if self.relational_db and self.memory_unit_repo:
                self._update_relational_db(memory_unit)

            # Record update event
            if self.relational_db and self.memory_event_repo:
                self._record_event(
                    event_type="UPDATE",
                    memory_unit=memory_unit,
                    previous_state=previous_state
                )

            logger.info(f"Updated MemoryUnit with ID: {unit_id} in all databases.")
        except Exception as e:
            logger.error(f"Error updating MemoryUnit with ID {unit_id}: {e}")
            self.handle_error(e)
            raise

    def delete_memory_unit(self, unit_id: str) -> None:
        """
        Deletes a MemoryUnit from all configured databases.

        Args:
            unit_id (str): The unique identifier of the memory unit.
        """
        try:
            # Retrieve existing MemoryUnit
            memory_unit = self.get_memory_unit(unit_id)

            # Delete from vector database
            self._delete_from_vector_db(unit_id)

            # Delete from graph database
            if self.graph_db:
                self._delete_from_graph_db(unit_id)

            # Delete from relational database
            if self.relational_db and self.memory_unit_repo:
                self._delete_relational_db(unit_id)

            # Record deletion event
            if self.relational_db and self.memory_event_repo:
                self._record_event(
                    event_type="DELETE",
                    memory_unit=memory_unit
                )

            logger.info(f"Deleted MemoryUnit with ID: {unit_id} from all databases.")
        except Exception as e:
            logger.error(f"Error deleting MemoryUnit with ID {unit_id}: {e}")
            self.handle_error(e)
            raise

    def get_all_memory_units(self) -> List[MemoryUnit]:
        """
        Retrieves all MemoryUnits from the relational database.

        Returns:
            List[MemoryUnit]: A list of all memory units.
        """
        try:
            if not self.relational_db or not self.memory_unit_repo:
                raise ValueError("Relational database is not configured.")

            memory_units = self.memory_unit_repo.list_all()
            logger.info(f"Retrieved all MemoryUnits from Relational DB. Total count: {len(memory_units)}")
            return memory_units
        except Exception as e:
            logger.error(f"Error retrieving all MemoryUnits: {e}")
            self.handle_error(e)
            raise

    def delete_all_memory_units(self) -> None:
        """
        Deletes all MemoryUnits from all configured databases.
        """
        try:
            # Delete from vector database
            self.vector_db.delete_collection(
                collection_name=self.config.vector_db_config.collection_name
            )

            # Delete all nodes from graph database
            if self.graph_db:
                self.graph_db.delete_all()

            # Delete all records from relational database
            if self.relational_db and self.memory_unit_repo and self.memory_event_repo:
                self.memory_unit_repo.delete_all()
                self.memory_event_repo.delete_all()

            logger.info("Deleted all MemoryUnits from all databases.")
        except Exception as e:
            logger.error(f"Error deleting all MemoryUnits: {e}")
            self.handle_error(e)
            raise

    # Internal helper methods

    def _add_to_vector_db(self, memory_unit: MemoryUnit) -> None:
        """
        Adds the embedding vector of a MemoryUnit to the vector database.

        Args:
            memory_unit (MemoryUnit): The memory unit to add.
        """
        try:
            # Ensure embedding exists
            if not memory_unit.embedding:
                raise ValueError("MemoryUnit does not have an embedding.")

            # Prepare payload with essential fields
            payload = {
                "id": memory_unit.id,
                "type": memory_unit.type,
                "modality": memory_unit.modality
            }

            # Insert into vector database
            self.vector_db.insert_vectors(
                collection_name=self.config.vector_db_config.collection_name,
                vectors=[memory_unit.embedding],
                payloads=[payload],
                ids=[memory_unit.id]
            )

            logger.info(f"Inserted embedding for MemoryUnit ID: {memory_unit.id} into Vector DB.")
        except Exception as e:
            logger.error(f"Error adding MemoryUnit to Vector DB: {e}")
            self.handle_error(e)
            raise

    def _update_vector_db(self, memory_unit: MemoryUnit) -> None:
        """
        Updates the embedding vector of a MemoryUnit in the vector database.

        Args:
            memory_unit (MemoryUnit): The memory unit to update.
        """
        try:
            if not memory_unit.embedding:
                raise ValueError("MemoryUnit does not have an embedding.")

            payload = {
                "type": memory_unit.type,
                "modality": memory_unit.modality
            }

            self.vector_db.update_vector(
                collection_name=self.config.vector_db_config.collection_name,
                vector_id=memory_unit.id,
                vector=memory_unit.embedding,
                payload=payload
            )

            logger.info(f"Updated embedding for MemoryUnit ID: {memory_unit.id} in Vector DB.")
        except Exception as e:
            logger.error(f"Error updating MemoryUnit in Vector DB: {e}")
            self.handle_error(e)
            raise

    def _delete_from_vector_db(self, unit_id: str) -> None:
        """
        Deletes a MemoryUnit's embedding from the vector database.

        Args:
            unit_id (str): The unique identifier of the memory unit.
        """
        try:
            self.vector_db.delete_vector(
                collection_name=self.config.vector_db_config.collection_name,
                vector_id=unit_id
            )

            logger.info(f"Deleted embedding for MemoryUnit ID: {unit_id} from Vector DB.")
        except Exception as e:
            logger.error(f"Error deleting MemoryUnit from Vector DB: {e}")
            self.handle_error(e)
            raise

    def _add_to_graph_db(self, memory_unit: MemoryUnit) -> None:
        """
        Adds a MemoryUnit as a node in the graph database and establishes relationships.

        Args:
            memory_unit (MemoryUnit): The memory unit to add.
        """
        try:
            # Serialize complex fields
            properties = {
                "id": memory_unit.id,
                "content": memory_unit.content,
                "timestamp": memory_unit.timestamp.isoformat(),
                "modality": memory_unit.modality,
                "type": memory_unit.type,
                "status": memory_unit.status,
                "tags": memory_unit.tags,
                "embedding": memory_unit.embedding,
                "location": json.dumps(memory_unit.location) if memory_unit.location else None,  # Serialized
                "source_role": memory_unit.source_role,
                "source_name": memory_unit.source_name,
                "source_id": memory_unit.source_id,
                "metadata": json.dumps(memory_unit.metadata) if memory_unit.metadata else None  # Serialized
            }

            # Add node to graph database
            self.graph_db.add_node(
                node_id=memory_unit.id,
                properties=properties,
                labels=[memory_unit.type or 'MemoryUnit']
            )

            logger.info(f"Added MemoryUnit ID: {memory_unit.id} to Graph DB.")

            # Add relationships (edges) if any
            for link in memory_unit.edges:
                # Serialize edge metadata if necessary
                edge_properties = {}
                if link.metadata:
                    edge_properties['metadata'] = json.dumps(link.metadata)

                self.graph_db.add_edge(
                    source_id=link.source_id,
                    target_id=link.target_id,
                    relationship=link.relationship,
                    properties=edge_properties
                )

            logger.info(f"Added {len(memory_unit.edges)} edges for MemoryUnit ID: {memory_unit.id} in Graph DB.")
        except Exception as e:
            logger.error(f"Error adding MemoryUnit to Graph DB: {e}")
            self.handle_error(e)
            raise

    def _update_graph_db(self, memory_unit: MemoryUnit) -> None:
        """
        Updates a MemoryUnit in the graph database.

        Args:
            memory_unit (MemoryUnit): The memory unit to update.
        """
        try:
            # Update node properties
            properties = {
                "content": memory_unit.content,
                "timestamp": memory_unit.timestamp.isoformat(),
                "modality": memory_unit.modality,
                "type": memory_unit.type,
                "status": memory_unit.status,
                "tags": memory_unit.tags,
                "embedding": memory_unit.embedding,
                "location": json.dumps(memory_unit.location) if memory_unit.location else None,  # Serialized
                "source_role": memory_unit.source_role,
                "source_name": memory_unit.source_name,
                "source_id": memory_unit.source_id,
                "metadata": json.dumps(memory_unit.metadata) if memory_unit.metadata else None  # Serialized
            }

            self.graph_db.update_node(
                node_id=memory_unit.id,
                properties=properties
            )

            # Handle edges updates as needed
            # This can be complex and depends on your specific requirements

            logger.info(f"Updated MemoryUnit ID: {memory_unit.id} in Graph DB.")
        except Exception as e:
            logger.error(f"Error updating MemoryUnit in Graph DB: {e}")
            self.handle_error(e)
            raise

    def _delete_from_graph_db(self, unit_id: str) -> None:
        """
        Deletes a MemoryUnit from the graph database.

        Args:
            unit_id (str): The unique identifier of the memory unit.
        """
        try:
            self.graph_db.delete_node(node_id=unit_id)
            logger.info(f"Deleted MemoryUnit ID: {unit_id} from Graph DB.")
        except Exception as e:
            logger.error(f"Error deleting MemoryUnit from Graph DB: {e}")
            self.handle_error(e)
            raise

    def _add_to_relational_db(self, memory_unit: MemoryUnit) -> None:
        """
        Adds a MemoryUnit to the relational database.

        Args:
            memory_unit (MemoryUnit): The memory unit to add.
        """
        try:
            self.memory_unit_repo.add(memory_unit)
            logger.info(f"Inserted MemoryUnit ID: {memory_unit.id} into Relational DB.")
        except Exception as e:
            logger.error(f"Error adding MemoryUnit to Relational DB: {e}")
            raise

    def _update_relational_db(self, memory_unit: MemoryUnit) -> None:
        """
        Updates a MemoryUnit in the relational database.

        Args:
            memory_unit (MemoryUnit): The memory unit to update.
        """
        try:
            self.memory_unit_repo.update(memory_unit)
            logger.info(f"Updated MemoryUnit ID: {memory_unit.id} in Relational DB.")
        except Exception as e:
            logger.error(f"Error updating MemoryUnit in Relational DB: {e}")
            raise

    def _delete_relational_db(self, unit_id: str) -> None:
        """
        Deletes a MemoryUnit from the relational database.

        Args:
            unit_id (str): The unique identifier of the memory unit.
        """
        try:
            self.memory_unit_repo.delete(unit_id)
            logger.info(f"Deleted MemoryUnit ID: {unit_id} from Relational DB.")
        except Exception as e:
            logger.error(f"Error deleting MemoryUnit from Relational DB: {e}")
            raise

    def _record_event(self, event_type: str, memory_unit: MemoryUnit, previous_state: Optional[Dict[str, Any]] = None) -> None:
        """
        Records an event in the relational database.

        Args:
            event_type (str): The type of event ('CREATE', 'UPDATE', 'DELETE').
            memory_unit (MemoryUnit): The memory unit involved in the event.
            previous_state (Optional[Dict[str, Any]]): The previous state of the memory unit (for updates).
        """
        try:
            event_record = {
                "memory_id": memory_unit.id,
                "event_type": event_type,
                "memory_data": json.dumps(memory_unit.to_dict()),
                "previous_data": json.dumps(previous_state) if previous_state else None
            }

            self.memory_event_repo.add(event_record)
            logger.info(f"Recorded event '{event_type}' for MemoryUnit ID: {memory_unit.id}.")
        except Exception as e:
            logger.error(f"Error recording event in Relational DB: {e}")
            raise

    def retrieve_similar_memory_units(self, query_embedding: List[float], top_k: int) -> List[MemoryUnit]:
        """
        Retrieves memory units similar to the given embedding.

        Args:
            query_embedding (List[float]): The embedding vector of the query.
            top_k (int): The number of similar units to retrieve.

        Returns:
            List[MemoryUnit]: A list of similar memory units.
        """
        try:
            # Perform similarity search
            results = self.vector_db.search_vectors(
                collection_name=self.config.vector_db_config.collection_name,
                query_vector=query_embedding,
                top_k=top_k
            )

            memory_units = []
            for result in results:
                unit_id = result['id']
                memory_unit = self.get_memory_unit(unit_id)
                memory_units.append(memory_unit)

            logger.info(f"Retrieved {len(memory_units)} similar MemoryUnits.")
            return memory_units
        except Exception as e:
            logger.error(f"Error retrieving similar MemoryUnits: {e}")
            self.handle_error(e)
            raise

    def retrieve_related_memory_units(self, unit_id: str, relationship: Optional[str] = None) -> List[MemoryUnit]:
        """
        Retrieves memory units related to the given one based on relationships.

        Args:
            unit_id (str): The unique identifier of the memory unit.
            relationship (Optional[str]): Filter by relationship type.

        Returns:
            List[MemoryUnit]: A list of related memory units.
        """
        try:
            if not self.graph_db:
                raise ValueError("Graph database is not configured.")

            # Retrieve related nodes from graph database
            neighbors = self.graph_db.get_neighbors(
                node_id=unit_id,
                relationship=relationship
            )

            related_units = []
            for neighbor in neighbors:
                related_unit = self.get_memory_unit(neighbor['id'])
                related_units.append(related_unit)

            logger.info(f"Retrieved {len(related_units)} related MemoryUnits.")
            return related_units
        except Exception as e:
            logger.error(f"Error retrieving related MemoryUnits: {e}")
            self.handle_error(e)
            raise
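
setup() above reads three optional nested sections from the config dict, falling back to the defaults visible in the source. A configuration sketch using those defaults; the URIs, file paths, and password are placeholders:

config = {
    "vector_db_config": {
        "provider_name": "milvus",
        "uri": "storage/milvus_demo.db",
        "collection_name": "test_collection",
        "embedding_model_dims": 1536,   # e.g. for 'text-embedding-ada-002'
        "metric_type": "COSINE",
    },
    "graph_db_config": {
        "provider_name": "neo4j",
        "uri": "bolt://localhost:7687",
        "user": "neo4j",
        "password": "<placeholder>",
        "database": "neo4j",
        "encrypted": False,
    },
    "relational_db_config": {
        "provider_name": "sqlite",
        "database": "storage/test_database.db",
    },
}

storage = MemoryStorage(config)  # __init__ calls setup() immediately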
__init__(config)

Initialize the MemoryStorage with the provided configuration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | Any | Configuration settings for MemoryStorage. | required |
Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, config: Dict):
    """
    Initialize the MemoryStorage with the provided configuration.

    Args:
        config (Any): Configuration settings for MemoryStorage.
    """
    self.config_dict = config
    self.config = None
    self.setup()
add_memory_unit(memory_unit)

Adds a MemoryUnit to all configured databases.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| memory_unit | MemoryUnit | The memory unit to add. | required |
Source code in src/aeiva/cognition/memory/memory_storage.py
def add_memory_unit(self, memory_unit: MemoryUnit) -> None:
    """
    Adds a MemoryUnit to all configured databases.

    Args:
        memory_unit (MemoryUnit): The memory unit to add.
    """
    try:
        # Add to vector database
        self._add_to_vector_db(memory_unit)

        # Add to graph database
        if self.graph_db:
            self._add_to_graph_db(memory_unit)

        # Add to relational database
        if self.relational_db and self.memory_unit_repo:
            self._add_to_relational_db(memory_unit)

        # Record creation event
        if self.relational_db and self.memory_event_repo:
            self._record_event(
                event_type="CREATE",
                memory_unit=memory_unit
            )

        logger.info(f"Added MemoryUnit with ID: {memory_unit.id} to all databases.")
    except Exception as e:
        logger.error(f"Error adding MemoryUnit to databases: {e}")
        self.handle_error(e)
        raise
delete_all_memory_units()

Deletes all MemoryUnits from all configured databases.

Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all_memory_units(self) -> None:
    """
    Deletes all MemoryUnits from all configured databases.
    """
    try:
        # Delete from vector database
        self.vector_db.delete_collection(
            collection_name=self.config.vector_db_config.collection_name
        )

        # Delete all nodes from graph database
        if self.graph_db:
            self.graph_db.delete_all()

        # Delete all records from relational database
        if self.relational_db and self.memory_unit_repo and self.memory_event_repo:
            self.memory_unit_repo.delete_all()
            self.memory_event_repo.delete_all()

        logger.info("Deleted all MemoryUnits from all databases.")
    except Exception as e:
        logger.error(f"Error deleting all MemoryUnits: {e}")
        self.handle_error(e)
        raise
delete_memory_unit(unit_id)

Deletes a MemoryUnit from all configured databases.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| unit_id | str | The unique identifier of the memory unit. | required |
Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_memory_unit(self, unit_id: str) -> None:
    """
    Deletes a MemoryUnit from all configured databases.

    Args:
        unit_id (str): The unique identifier of the memory unit.
    """
    try:
        # Retrieve existing MemoryUnit
        memory_unit = self.get_memory_unit(unit_id)

        # Delete from vector database
        self._delete_from_vector_db(unit_id)

        # Delete from graph database
        if self.graph_db:
            self._delete_from_graph_db(unit_id)

        # Delete from relational database
        if self.relational_db and self.memory_unit_repo:
            self._delete_relational_db(unit_id)

        # Record deletion event
        if self.relational_db and self.memory_event_repo:
            self._record_event(
                event_type="DELETE",
                memory_unit=memory_unit
            )

        logger.info(f"Deleted MemoryUnit with ID: {unit_id} from all databases.")
    except Exception as e:
        logger.error(f"Error deleting MemoryUnit with ID {unit_id}: {e}")
        self.handle_error(e)
        raise
get_all_memory_units()

Retrieves all MemoryUnits from the relational database.

Returns:

| Type | Description |
| --- | --- |
| List[MemoryUnit] | A list of all memory units. |

Source code in src/aeiva/cognition/memory/memory_storage.py
def get_all_memory_units(self) -> List[MemoryUnit]:
    """
    Retrieves all MemoryUnits from the relational database.

    Returns:
        List[MemoryUnit]: A list of all memory units.
    """
    try:
        if not self.relational_db or not self.memory_unit_repo:
            raise ValueError("Relational database is not configured.")

        memory_units = self.memory_unit_repo.list_all()
        logger.info(f"Retrieved all MemoryUnits from Relational DB. Total count: {len(memory_units)}")
        return memory_units
    except Exception as e:
        logger.error(f"Error retrieving all MemoryUnits: {e}")
        self.handle_error(e)
        raise
get_memory_unit(unit_id)

Retrieves a MemoryUnit by its unique identifier from the relational database.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| unit_id | str | The unique identifier of the memory unit. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| MemoryUnit | MemoryUnit | The retrieved memory unit. |

Source code in src/aeiva/cognition/memory/memory_storage.py
def get_memory_unit(self, unit_id: str) -> MemoryUnit:
    """
    Retrieves a MemoryUnit by its unique identifier from the relational database.

    Args:
        unit_id (str): The unique identifier of the memory unit.

    Returns:
        MemoryUnit: The retrieved memory unit.
    """
    try:
        if not self.relational_db or not self.memory_unit_repo:
            raise ValueError("Relational database is not configured.")

        memory_unit = self.memory_unit_repo.get(unit_id)
        if not memory_unit:
            raise ValueError(f"MemoryUnit with ID {unit_id} does not exist.")

        logger.info(f"Retrieved MemoryUnit with ID: {unit_id} from Relational DB.")
        return memory_unit
    except Exception as e:
        logger.error(f"Error retrieving MemoryUnit with ID {unit_id}: {e}")
        self.handle_error(e)
        raise
handle_error(error)

Handle errors that occur during storage operations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| error | Exception | The exception that was raised. | required |
Source code in src/aeiva/cognition/memory/memory_storage.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during storage operations.

    Args:
        error (Exception): The exception that was raised.
    """
    logger.error(f"MemoryStorage encountered an error: {error}")

retrieve_related_memory_units(unit_id, relationship=None)

Retrieves memory units related to the given one based on relationships.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| unit_id | str | The unique identifier of the memory unit. | required |
| relationship | Optional[str] | Filter by relationship type. | None |

Returns:

| Type | Description |
| --- | --- |
| List[MemoryUnit] | A list of related memory units. |

Source code in src/aeiva/cognition/memory/memory_storage.py
def retrieve_related_memory_units(self, unit_id: str, relationship: Optional[str] = None) -> List[MemoryUnit]:
    """
    Retrieves memory units related to the given one based on relationships.

    Args:
        unit_id (str): The unique identifier of the memory unit.
        relationship (Optional[str]): Filter by relationship type.

    Returns:
        List[MemoryUnit]: A list of related memory units.
    """
    try:
        if not self.graph_db:
            raise ValueError("Graph database is not configured.")

        # Retrieve related nodes from graph database
        neighbors = self.graph_db.get_neighbors(
            node_id=unit_id,
            relationship=relationship
        )

        related_units = []
        for neighbor in neighbors:
            related_unit = self.get_memory_unit(neighbor['id'])
            related_units.append(related_unit)

        logger.info(f"Retrieved {len(related_units)} related MemoryUnits.")
        return related_units
    except Exception as e:
        logger.error(f"Error retrieving related MemoryUnits: {e}")
        self.handle_error(e)
        raise
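
For illustration, a minimal usage sketch (the unit ID and relationship type below are hypothetical; `storage` is assumed to be a MemoryStorage instance whose setup() has already run):

# Hypothetical sketch: fetch a unit's neighbors in the graph database.
related = storage.retrieve_related_memory_units(
    unit_id="some-unit-id",        # hypothetical ID of an existing unit
    relationship="DERIVED_FROM"    # hypothetical relationship type; None returns all neighbors
)
for unit in related:
    print(unit.id, unit.type)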
retrieve_similar_memory_units(query_embedding, top_k)

Retrieves memory units similar to the given embedding.

Parameters:

    query_embedding (List[float]): The embedding vector of the query. Required.
    top_k (int): The number of similar units to retrieve. Required.

Returns:

    List[MemoryUnit]: A list of similar memory units.

Source code in src/aeiva/cognition/memory/memory_storage.py
def retrieve_similar_memory_units(self, query_embedding: List[float], top_k: int) -> List[MemoryUnit]:
    """
    Retrieves memory units similar to the given embedding.

    Args:
        query_embedding (List[float]): The embedding vector of the query.
        top_k (int): The number of similar units to retrieve.

    Returns:
        List[MemoryUnit]: A list of similar memory units.
    """
    try:
        # Perform similarity search
        results = self.vector_db.search_vectors(
            collection_name=self.config.vector_db_config.collection_name,
            query_vector=query_embedding,
            top_k=top_k
        )

        memory_units = []
        for result in results:
            unit_id = result['id']
            memory_unit = self.get_memory_unit(unit_id)
            memory_units.append(memory_unit)

        logger.info(f"Retrieved {len(memory_units)} similar MemoryUnits.")
        return memory_units
    except Exception as e:
        logger.error(f"Error retrieving similar MemoryUnits: {e}")
        self.handle_error(e)
        raise
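
As a sketch of the similarity-retrieval flow, assuming `storage` is a configured MemoryStorage instance and `embed_text` is a hypothetical stand-in for whatever embedding function your pipeline uses (its output must match the configured embedding_model_dims):

# Hypothetical sketch: embed a query and fetch the five most similar units.
query_embedding = embed_text("notes about the project roadmap")  # hypothetical helper
similar = storage.retrieve_similar_memory_units(query_embedding, top_k=5)
for unit in similar:
    print(unit.id, unit.content)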
setup()

Set up the MemoryStorage's components based on the provided configuration.

Source code in src/aeiva/cognition/memory/memory_storage.py
def setup(self) -> None:
    """
    Set up the MemoryStorage's components based on the provided configuration.
    """
    try:
        # Initialize Vector Database Configuration
        vector_db_conf_dict = self.config_dict.get('vector_db_config', {})
        vector_db_provider_name = vector_db_conf_dict.get('provider_name', 'milvus')
        vector_db_config = DatabaseConfigFactory.create(
            provider_name=vector_db_conf_dict.get('provider_name', 'milvus'),
            uri=vector_db_conf_dict.get('uri', 'storage/milvus_demo.db'),
            collection_name=vector_db_conf_dict.get('collection_name', 'test_collection'),
            embedding_model_dims=vector_db_conf_dict.get('embedding_model_dims', 1536),  # 'text-embedding-ada-002': 1536,
            metric_type=vector_db_conf_dict.get('metric_type', 'COSINE')
        )

        # Initialize Graph Database Configuration
        graph_db_conf_dict = self.config_dict.get('graph_db_config', {})
        graph_db_provider_name = graph_db_conf_dict.get('provider_name', 'neo4j')
        graph_db_password = graph_db_conf_dict.get('password')
        graph_db_config = DatabaseConfigFactory.create(
            provider_name=graph_db_conf_dict.get('provider_name', 'neo4j'),
            uri=graph_db_conf_dict.get('uri', 'bolt://localhost:7687'),
            user=graph_db_conf_dict.get('user', 'neo4j'),
            password=graph_db_password,
            database=graph_db_conf_dict.get('database', 'neo4j'),
            encrypted=graph_db_conf_dict.get('encrypted', False)
        )

        # Initialize Relational Database Configuration
        relational_db_conf_dict = self.config_dict.get('relational_db_config', {})
        relational_db_provider_name = relational_db_conf_dict.get('provider_name', 'sqlite')
        relational_db_config = DatabaseConfigFactory.create(
            provider_name=relational_db_conf_dict.get('provider_name', 'sqlite'),
            database=relational_db_conf_dict.get('database', 'storage/test_database.db')
        )

        self.config = StorageConfig(
            vector_db_provider=self.config_dict.get('vector_db_provider', 'milvus'),
            vector_db_config=vector_db_config,
            graph_db_provider=self.config_dict.get('graph_db_provider', 'neo4j'),
            graph_db_config=graph_db_config,
            relational_db_provider=self.config_dict.get('relational_db_provider', 'sqlite'),
            relational_db_config=relational_db_config,
        )

        # Initialize the vector database
        self.vector_db = DatabaseFactory.create(
            provider_name=vector_db_provider_name,
            config=self.config.vector_db_config
        )

        # Initialize the graph database if provided
        if graph_db_provider_name and self.config.graph_db_config:
            self.graph_db = DatabaseFactory.create(
                provider_name=graph_db_provider_name,
                config=self.config.graph_db_config
            )
        else:
            self.graph_db = None

        # Initialize the relational database if provided
        if relational_db_provider_name and self.config.relational_db_config:
            self.relational_db = DatabaseFactory.create(
                provider_name=relational_db_provider_name,
                config=self.config.relational_db_config
            )
            self.memory_unit_repo = MemoryUnitRepository(self.relational_db)
            self.memory_event_repo = MemoryEventRepository(self.relational_db)
        else:
            self.relational_db = None
            self.memory_unit_repo = None
            self.memory_event_repo = None

        logger.info("MemoryStorage setup completed successfully.")
    except Exception as e:
        logger.error(f"Error during MemoryStorage setup: {e}")
        self.handle_error(e)
        raise  # Re-raise the exception after logging
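
Based on the keys read in setup() above, a config_dict might look like the following sketch (all values are illustrative; the defaults shown in the code apply when a key is omitted):

config_dict = {
    'vector_db_config': {
        'provider_name': 'milvus',
        'uri': 'storage/milvus_demo.db',
        'collection_name': 'test_collection',
        'embedding_model_dims': 1536,   # e.g., 'text-embedding-ada-002' uses 1536
        'metric_type': 'COSINE',
    },
    'graph_db_config': {
        'provider_name': 'neo4j',
        'uri': 'bolt://localhost:7687',
        'user': 'neo4j',
        'password': 'your-password',    # placeholder
        'database': 'neo4j',
        'encrypted': False,
    },
    'relational_db_config': {
        'provider_name': 'sqlite',
        'database': 'storage/test_database.db',
    },
}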
update_memory_unit(unit_id, updates)

Updates a MemoryUnit in all configured databases.

Parameters:

    unit_id (str): The unique identifier of the memory unit. Required.
    updates (Dict[str, Any]): The updates to apply. Required.
Source code in src/aeiva/cognition/memory/memory_storage.py
def update_memory_unit(self, unit_id: str, updates: Dict[str, Any]) -> None:
    """
    Updates a MemoryUnit in all configured databases.

    Args:
        unit_id (str): The unique identifier of the memory unit.
        updates (Dict[str, Any]): The updates to apply.
    """
    try:
        # Retrieve existing MemoryUnit
        memory_unit = self.get_memory_unit(unit_id)
        previous_state = memory_unit.to_dict()

        # Apply updates
        for key, value in updates.items():
            setattr(memory_unit, key, value)

        # Update in vector database
        self._update_vector_db(memory_unit)

        # Update in graph database
        if self.graph_db:
            self._update_graph_db(memory_unit)

        # Update in relational database
        if self.relational_db and self.memory_unit_repo:
            self._update_relational_db(memory_unit)

        # Record update event
        if self.relational_db and self.memory_event_repo:
            self._record_event(
                event_type="UPDATE",
                memory_unit=memory_unit,
                previous_state=previous_state
            )

        logger.info(f"Updated MemoryUnit with ID: {unit_id} in all databases.")
    except Exception as e:
        logger.error(f"Error updating MemoryUnit with ID {unit_id}: {e}")
        self.handle_error(e)
        raise
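
A minimal usage sketch (the ID and field values are hypothetical; `storage` is assumed to be a configured MemoryStorage instance):

# Hypothetical sketch: apply a partial update; fields not listed are preserved.
storage.update_memory_unit(
    unit_id="some-unit-id",
    updates={"status": "processed", "tags": ["summarized"]}
)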
MemoryUnitRepository

Repository for MemoryUnit to handle CRUD operations without SQLAlchemy.

Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryUnitRepository:
    """
    Repository for MemoryUnit to handle CRUD operations without SQLAlchemy.
    """

    def __init__(self, db: Any):
        """
        Initialize the repository with a DatabaseFactory instance.

        Args:
            db (Any): An instance of DatabaseFactory for relational databases.
        """
        self.db = db
        self.table_name = 'memory_units'
        self._create_table()

    def _create_table(self):
        """
        Creates the memory_units table if it does not exist.
        """
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {self.table_name} (
            id TEXT PRIMARY KEY,
            content TEXT NOT NULL,
            timestamp TEXT NOT NULL,
            modality TEXT,
            type TEXT,
            status TEXT,
            tags TEXT,
            embedding TEXT,
            location TEXT,
            source_role TEXT,
            source_name TEXT,
            source_id TEXT,
            edges TEXT,
            metadata TEXT
        );
        """
        self.db.execute_sql(create_table_query)

    def add(self, memory_unit: MemoryUnit) -> None:
        """
        Adds a MemoryUnit to the relational database.

        Args:
            memory_unit (MemoryUnit): The memory unit to add.
        """
        insert_query = f"""
        INSERT INTO {self.table_name} (id, content, timestamp, modality, type, status, tags, embedding, location, 
            source_role, source_name, source_id, edges, metadata)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
        """
        data = (
            memory_unit.id,
            memory_unit.content,
            memory_unit.timestamp.isoformat(),
            memory_unit.modality,
            memory_unit.type,
            memory_unit.status,
            json.dumps(memory_unit.tags),
            json.dumps(memory_unit.embedding) if memory_unit.embedding else None,
            json.dumps(memory_unit.location) if memory_unit.location else None,
            memory_unit.source_role,
            memory_unit.source_name,
            memory_unit.source_id,
            json.dumps([link.to_dict() for link in memory_unit.edges]),
            json.dumps(memory_unit.metadata) if memory_unit.metadata else None
        )
        self.db.execute_sql(insert_query, data)

    def get(self, unit_id: str) -> Optional[MemoryUnit]:
        """
        Retrieves a MemoryUnit by its ID.

        Args:
            unit_id (str): The unique identifier of the memory unit.

        Returns:
            Optional[MemoryUnit]: The retrieved memory unit or None if not found.
        """
        select_query = f"SELECT * FROM {self.table_name} WHERE id = ?;"
        result = self.db.execute_sql(select_query, (unit_id,))
        row = result.fetchone()
        if row:
            return self._row_to_memory_unit(row)
        return None

    def update(self, memory_unit: MemoryUnit) -> None:
        """
        Updates an existing MemoryUnit in the relational database.

        Args:
            memory_unit (MemoryUnit): The memory unit with updated data.
        """
        update_query = f"""
        UPDATE {self.table_name}
        SET content = ?, timestamp = ?, modality = ?, type = ?, status = ?, tags = ?, embedding = ?, 
            location = ?, source_role = ?, source_name = ?, source_id = ?, edges = ?, metadata = ?
        WHERE id = ?;
        """
        data = (
            memory_unit.content,
            memory_unit.timestamp.isoformat(),
            memory_unit.modality,
            memory_unit.type,
            memory_unit.status,
            json.dumps(memory_unit.tags),
            json.dumps(memory_unit.embedding) if memory_unit.embedding else None,
            json.dumps(memory_unit.location) if memory_unit.location else None,
            memory_unit.source_role,
            memory_unit.source_name,
            memory_unit.source_id,
            json.dumps([link.to_dict() for link in memory_unit.edges]),
            json.dumps(memory_unit.metadata) if memory_unit.metadata else None,
            memory_unit.id
        )
        self.db.execute_sql(update_query, data)

    def delete(self, unit_id: str) -> None:
        """
        Deletes a MemoryUnit from the relational database.

        Args:
            unit_id (str): The unique identifier of the memory unit to delete.
        """
        delete_query = f"DELETE FROM {self.table_name} WHERE id = ?;"
        self.db.execute_sql(delete_query, (unit_id,))

    def list_all(self) -> List[MemoryUnit]:
        """
        Retrieves all MemoryUnits from the relational database.

        Returns:
            List[MemoryUnit]: A list of all memory units.
        """
        select_query = f"SELECT * FROM {self.table_name};"
        results = self.db.execute_sql(select_query)
        return [self._row_to_memory_unit(row) for row in results.fetchall()]

    def delete_all(self) -> None:
        """
        Deletes all MemoryUnits from the relational database.
        """
        delete_query = f"DELETE FROM {self.table_name};"
        self.db.execute_sql(delete_query)

    def _row_to_memory_unit(self, row: Any) -> MemoryUnit:
        """
        Converts a database row to a MemoryUnit instance.

        Args:
            row (Any): A row fetched from the database.

        Returns:
            MemoryUnit: The corresponding MemoryUnit instance.
        """
        return MemoryUnit(
            id=row['id'],
            content=row['content'],
            timestamp=datetime.fromisoformat(row['timestamp']),
            modality=row['modality'],
            type=row['type'],
            status=row['status'],
            tags=json.loads(row['tags']) if row['tags'] else [],
            embedding=json.loads(row['embedding']) if row['embedding'] else [],
            location=json.loads(row['location']) if row['location'] else {},
            source_role=row['source_role'],
            source_name=row['source_name'],
            source_id=row['source_id'],
            edges=[MemoryLink.from_dict(link) for link in json.loads(row['edges'])] if row['edges'] else [],
            metadata=json.loads(row['metadata']) if row['metadata'] else {}
        )
__init__(db)

Initialize the repository with a DatabaseFactory instance.

Parameters:

    db (Any): An instance of DatabaseFactory for relational databases. Required.
Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, db: Any):
    """
    Initialize the repository with a DatabaseFactory instance.

    Args:
        db (Any): An instance of DatabaseFactory for relational databases.
    """
    self.db = db
    self.table_name = 'memory_units'
    self._create_table()
add(memory_unit)

Adds a MemoryUnit to the relational database.

Parameters:

    memory_unit (MemoryUnit): The memory unit to add. Required.
Source code in src/aeiva/cognition/memory/memory_storage.py
def add(self, memory_unit: MemoryUnit) -> None:
    """
    Adds a MemoryUnit to the relational database.

    Args:
        memory_unit (MemoryUnit): The memory unit to add.
    """
    insert_query = f"""
    INSERT INTO {self.table_name} (id, content, timestamp, modality, type, status, tags, embedding, location, 
        source_role, source_name, source_id, edges, metadata)
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
    """
    data = (
        memory_unit.id,
        memory_unit.content,
        memory_unit.timestamp.isoformat(),
        memory_unit.modality,
        memory_unit.type,
        memory_unit.status,
        json.dumps(memory_unit.tags),
        json.dumps(memory_unit.embedding) if memory_unit.embedding else None,
        json.dumps(memory_unit.location) if memory_unit.location else None,
        memory_unit.source_role,
        memory_unit.source_name,
        memory_unit.source_id,
        json.dumps([link.to_dict() for link in memory_unit.edges]),
        json.dumps(memory_unit.metadata) if memory_unit.metadata else None
    )
    self.db.execute_sql(insert_query, data)
delete(unit_id)

Deletes a MemoryUnit from the relational database.

Parameters:

    unit_id (str): The unique identifier of the memory unit to delete. Required.
Source code in src/aeiva/cognition/memory/memory_storage.py
def delete(self, unit_id: str) -> None:
    """
    Deletes a MemoryUnit from the relational database.

    Args:
        unit_id (str): The unique identifier of the memory unit to delete.
    """
    delete_query = f"DELETE FROM {self.table_name} WHERE id = ?;"
    self.db.execute_sql(delete_query, (unit_id,))
delete_all()

Deletes all MemoryUnits from the relational database.

Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all(self) -> None:
    """
    Deletes all MemoryUnits from the relational database.
    """
    delete_query = f"DELETE FROM {self.table_name};"
    self.db.execute_sql(delete_query)
get(unit_id)

Retrieves a MemoryUnit by its ID.

Parameters:

    unit_id (str): The unique identifier of the memory unit. Required.

Returns:

    Optional[MemoryUnit]: The retrieved memory unit or None if not found.

Source code in src/aeiva/cognition/memory/memory_storage.py
def get(self, unit_id: str) -> Optional[MemoryUnit]:
    """
    Retrieves a MemoryUnit by its ID.

    Args:
        unit_id (str): The unique identifier of the memory unit.

    Returns:
        Optional[MemoryUnit]: The retrieved memory unit or None if not found.
    """
    select_query = f"SELECT * FROM {self.table_name} WHERE id = ?;"
    result = self.db.execute_sql(select_query, (unit_id,))
    row = result.fetchone()
    if row:
        return self._row_to_memory_unit(row)
    return None
list_all()

Retrieves all MemoryUnits from the relational database.

Returns:

    List[MemoryUnit]: A list of all memory units.

Source code in src/aeiva/cognition/memory/memory_storage.py
def list_all(self) -> List[MemoryUnit]:
    """
    Retrieves all MemoryUnits from the relational database.

    Returns:
        List[MemoryUnit]: A list of all memory units.
    """
    select_query = f"SELECT * FROM {self.table_name};"
    results = self.db.execute_sql(select_query)
    return [self._row_to_memory_unit(row) for row in results.fetchall()]
update(memory_unit)

Updates an existing MemoryUnit in the relational database.

Parameters:

    memory_unit (MemoryUnit): The memory unit with updated data. Required.
Source code in src/aeiva/cognition/memory/memory_storage.py
def update(self, memory_unit: MemoryUnit) -> None:
    """
    Updates an existing MemoryUnit in the relational database.

    Args:
        memory_unit (MemoryUnit): The memory unit with updated data.
    """
    update_query = f"""
    UPDATE {self.table_name}
    SET content = ?, timestamp = ?, modality = ?, type = ?, status = ?, tags = ?, embedding = ?, 
        location = ?, source_role = ?, source_name = ?, source_id = ?, edges = ?, metadata = ?
    WHERE id = ?;
    """
    data = (
        memory_unit.content,
        memory_unit.timestamp.isoformat(),
        memory_unit.modality,
        memory_unit.type,
        memory_unit.status,
        json.dumps(memory_unit.tags),
        json.dumps(memory_unit.embedding) if memory_unit.embedding else None,
        json.dumps(memory_unit.location) if memory_unit.location else None,
        memory_unit.source_role,
        memory_unit.source_name,
        memory_unit.source_id,
        json.dumps([link.to_dict() for link in memory_unit.edges]),
        json.dumps(memory_unit.metadata) if memory_unit.metadata else None,
        memory_unit.id
    )
    self.db.execute_sql(update_query, data)
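
A minimal CRUD sketch for this repository, assuming `db` is a relational database handle created via DatabaseFactory.create(...) as in MemoryStorage.setup() above:

# Hypothetical sketch: the memory_units table is created on construction if missing.
repo = MemoryUnitRepository(db)

unit = MemoryUnit(content="hello world", modality="text")
repo.add(unit)

fetched = repo.get(unit.id)       # MemoryUnit, or None if the ID is unknown
unit.status = "cleaned"
repo.update(unit)

print(len(repo.list_all()))       # count of all stored units
repo.delete(unit.id)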

memory_structurer

MemoryStructurer

A class to structure memory units based on various structuring algorithms.

Supported structure types:

    - 'structure_type_example': Placeholder for future structuring algorithms.
Source code in src/aeiva/cognition/memory/memory_structurer.py
class MemoryStructurer:
    """
    A class to structure memory units based on various structuring algorithms.

    Supported structure types:
        - 'structure_type_example': Placeholder for future structuring algorithms.
    """

    def __init__(self):
        """
        Initializes the MemoryStructurer.

        Currently, no initialization parameters are required.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug("Initialized MemoryStructurer without default parameters.")

    def structure(
        self,
        memory_units: List[MemoryUnit],
        structure_type: str,
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Structures the provided memory units based on the specified structure type.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be structured.
            structure_type (str): The type of structuring algorithm to use ('structure_type_example').
            **kwargs: Additional parameters required for specific structurers.

        Returns:
            List[MemoryUnit]: The list of memory units after structuring.

        Raises:
            MemoryStructurerError: If an unknown structure_type is provided or if structuring fails.
        """
        self.logger.debug(f"Structuring memory units using structure_type='{structure_type}' with kwargs={kwargs}")
        try:
            if structure_type == 'structure_type_example':
                # Placeholder for actual structuring logic
                return self.structure_example(memory_units, **kwargs)
            else:
                self.logger.error(f"Unknown structure_type: {structure_type}")
                raise MemoryStructurerError(f"Unknown structure_type: {structure_type}")
        except MemoryStructurerError:
            # Re-raise custom errors without modification
            raise
        except Exception as e:
            self.logger.error(f"Failed to structure memory units: {e}")
            raise MemoryStructurerError(f"Failed to structure memory units: {e}")

    def structure_example(
        self,
        memory_units: List[MemoryUnit],
        **kwargs
    ) -> List[MemoryUnit]:
        """
        Example structuring method. Currently a placeholder that returns memory units unchanged.

        Args:
            memory_units (List[MemoryUnit]): The list of memory units to be structured.
            **kwargs: Additional parameters (currently unused).

        Returns:
            List[MemoryUnit]: The original list of memory units, unchanged.
        """
        self.logger.debug("Executing structure_example: No changes applied to memory units.")
        # Placeholder: No operation performed
        return memory_units
__init__()

Initializes the MemoryStructurer.

Currently, no initialization parameters are required.

Source code in src/aeiva/cognition/memory/memory_structurer.py
def __init__(self):
    """
    Initializes the MemoryStructurer.

    Currently, no initialization parameters are required.
    """
    self.logger = logging.getLogger(self.__class__.__name__)
    self.logger.debug("Initialized MemoryStructurer without default parameters.")
structure(memory_units, structure_type, **kwargs)

Structures the provided memory units based on the specified structure type.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to be structured. Required.
    structure_type (str): The type of structuring algorithm to use ('structure_type_example'). Required.
    **kwargs: Additional parameters required for specific structurers. Default: {}.

Returns:

    List[MemoryUnit]: The list of memory units after structuring.

Raises:

    MemoryStructurerError: If an unknown structure_type is provided or if structuring fails.

Source code in src/aeiva/cognition/memory/memory_structurer.py
def structure(
    self,
    memory_units: List[MemoryUnit],
    structure_type: str,
    **kwargs
) -> List[MemoryUnit]:
    """
    Structures the provided memory units based on the specified structure type.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be structured.
        structure_type (str): The type of structuring algorithm to use ('structure_type_example').
        **kwargs: Additional parameters required for specific structurers.

    Returns:
        List[MemoryUnit]: The list of memory units after structuring.

    Raises:
        MemoryStructurerError: If an unknown structure_type is provided or if structuring fails.
    """
    self.logger.debug(f"Structuring memory units using structure_type='{structure_type}' with kwargs={kwargs}")
    try:
        if structure_type == 'structure_type_example':
            # Placeholder for actual structuring logic
            return self.structure_example(memory_units, **kwargs)
        else:
            self.logger.error(f"Unknown structure_type: {structure_type}")
            raise MemoryStructurerError(f"Unknown structure_type: {structure_type}")
    except MemoryStructurerError:
        # Re-raise custom errors without modification
        raise
    except Exception as e:
        self.logger.error(f"Failed to structure memory units: {e}")
        raise MemoryStructurerError(f"Failed to structure memory units: {e}")
structure_example(memory_units, **kwargs)

Example structuring method. Currently a placeholder that returns memory units unchanged.

Parameters:

    memory_units (List[MemoryUnit]): The list of memory units to be structured. Required.
    **kwargs: Additional parameters (currently unused). Default: {}.

Returns:

    List[MemoryUnit]: The original list of memory units, unchanged.

Source code in src/aeiva/cognition/memory/memory_structurer.py
def structure_example(
    self,
    memory_units: List[MemoryUnit],
    **kwargs
) -> List[MemoryUnit]:
    """
    Example structuring method. Currently a placeholder that returns memory units unchanged.

    Args:
        memory_units (List[MemoryUnit]): The list of memory units to be structured.
        **kwargs: Additional parameters (currently unused).

    Returns:
        List[MemoryUnit]: The original list of memory units, unchanged.
    """
    self.logger.debug("Executing structure_example: No changes applied to memory units.")
    # Placeholder: No operation performed
    return memory_units
MemoryStructurerError

Bases: Exception

Exception raised when an error occurs in the MemoryStructurer.

Source code in src/aeiva/cognition/memory/memory_structurer.py
class MemoryStructurerError(Exception):
    """Exception raised when an error occurs in the MemoryStructurer."""
    pass
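
A short usage sketch for the structurer (memory unit contents are illustrative):

structurer = MemoryStructurer()
units = [MemoryUnit(content="note A"), MemoryUnit(content="note B")]

# The placeholder structure type returns the units unchanged;
# any other structure_type raises MemoryStructurerError.
structured = structurer.structure(units, structure_type='structure_type_example')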

memory_unit

MemoryUnit

Bases: BaseModel

MemoryUnit represents a single unit of memory with core content and rich metadata. It includes fields for tracking information about the memory’s source, modality, temporal and spatial attributes, and its connections to other memory units.

Essential Fields:

    id (str): Unique identifier for the memory unit, generated as a UUID string by default.
    content (Any): Core content of the memory, which is convertible to a string.

Source Information:

    source_role (Optional[str]): Role of the source, e.g., 'user', 'agent'.
    source_name (Optional[str]): Descriptive name of the source.
    source_id (Optional[str]): Unique identifier for the memory source, generated as a UUID string.

Connections:

    edges (List[MemoryLink]): List of edges connecting this memory unit to others.

Source code in src/aeiva/cognition/memory/memory_unit.py
class MemoryUnit(BaseModel):
    """
    MemoryUnit represents a single unit of memory with core content and rich metadata.
    It includes fields for tracking information about the memory’s source, modality,
    temporal and spatial attributes, and its connections to other memory units.

    Essential Fields:
        id (str): Unique identifier for the memory unit, generated as a UUID string by default.
        content (Any): Core content of the memory, which is convertible to a string.

    Metadata:
        timestamp (datetime): Creation timestamp, defaulting to the current time.
        modality (Optional[str]): Modality type, such as 'text', 'image', 'audio'.
        type (Optional[str]): Semantic type, such as 'dialogue', 'summary', 'document'.
        status (Optional[str]): Processing status, e.g., 'raw', 'cleaned', 'processed'.
        tags (Optional[List[str]]): Tags for categorization and filtering.
        embedding (Optional[List[float]]): Vector embedding for retrieval.
        location (Optional[Union[str, Dict]]): Spatial location data.

    Source Information:
        source_role (Optional[str]): Role of the source, e.g., 'user', 'agent'.
        source_name (Optional[str]): Descriptive name of the source.
        source_id (Optional[str]): Unique identifier for the memory source, generated as a UUID string.

    Connections:
        edges (List[MemoryLink]): List of edges connecting this memory unit to others.

    Additional Metadata:
        metadata (Optional[Dict[str, Any]]): Dictionary for extensible metadata.
    """

    # Essential Fields
    id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the memory unit.")
    content: Any = Field("", description="Core content of the memory unit, convertible to a string.")

    # Metadata Fields
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="Creation timestamp of the memory.")
    modality: Optional[str] = Field(None, description="Modality type, e.g., 'text', 'image', 'audio'.")
    type: Optional[str] = Field(None, description="Semantic type, e.g., 'dialogue', 'summary'.")
    status: Optional[str] = Field(None, description="Processing status, e.g., 'raw', 'cleaned', 'derived', 'grouped', 'structured', 'indexed'.")
    tags: Optional[List[str]] = Field(default_factory=list, description="Tags for categorization or filtering.")
    embedding: Optional[List[float]] = Field(None, description="Embedding vector for memory.")
    location: Optional[Union[str, Dict]] = Field(None, description="Location data as a string or structured dictionary.")

    # Source Information
    source_role: Optional[str] = Field(None, description="Role of the memory source, e.g., 'user', 'agent'.")
    source_name: Optional[str] = Field(None, description="Descriptive name of the source, e.g., 'User123'.")
    source_id: Optional[str] = Field(default_factory=lambda: uuid4().hex, description="Unique identifier associated with the source.")

    # Connections
    edges: List[MemoryLink] = Field(default_factory=list, description="List of edges linking this memory unit to others.")

    # Additional Metadata
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Dictionary for extensible metadata.")

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the MemoryUnit instance to a dictionary format for serialization.
        Each field is handled explicitly to ensure proper serialization.

        Returns:
            Dict[str, Any]: A dictionary representation of the MemoryUnit.
        """
        return {
            "id": self.id,
            "content": self.content,
            "timestamp": self.timestamp.isoformat(),  # Convert datetime to string
            "modality": self.modality,
            "type": self.type,
            "status": self.status,
            "tags": self.tags,
            "embedding": self.embedding,
            "location": self.location,
            "source_role": self.source_role,
            "source_name": self.source_name,
            "source_id": self.source_id,
            "edges": [edge.to_dict() for edge in self.edges],  # Serialize each MemoryLink
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: dict) -> "MemoryUnit":
        """
        Creates a MemoryUnit instance from a dictionary.
        Each field is handled explicitly to ensure proper deserialization.

        Args:
            data (dict): A dictionary containing MemoryUnit data.

        Returns:
            MemoryUnit: The created MemoryUnit instance.
        """
        try:
            return cls(
                id=data.get('id', uuid4().hex),
                content=data.get('content', ""),
                timestamp=datetime.fromisoformat(data['timestamp']) if 'timestamp' in data else datetime.now(timezone.utc),
                modality=data.get('modality'),
                type=data.get('type'),
                status=data.get('status'),
                tags=data.get('tags', []),
                embedding=data.get('embedding'),
                location=data.get('location'),
                source_role=data.get('source_role'),
                source_name=data.get('source_name'),
                source_id=data.get('source_id', uuid4().hex),
                edges=[MemoryLink.from_dict(edge) for edge in data.get('edges', [])],
                metadata=data.get('metadata', {})
            )
        except Exception as e:
            # logger.error(f"Error deserializing MemoryUnit from dict: {e}")
            raise e
from_dict(data) classmethod

Creates a MemoryUnit instance from a dictionary. Each field is handled explicitly to ensure proper deserialization.

Parameters:

    data (dict): A dictionary containing MemoryUnit data. Required.

Returns:

    MemoryUnit: The created MemoryUnit instance.

Source code in src/aeiva/cognition/memory/memory_unit.py
@classmethod
def from_dict(cls, data: dict) -> "MemoryUnit":
    """
    Creates a MemoryUnit instance from a dictionary.
    Each field is handled explicitly to ensure proper deserialization.

    Args:
        data (dict): A dictionary containing MemoryUnit data.

    Returns:
        MemoryUnit: The created MemoryUnit instance.
    """
    try:
        return cls(
            id=data.get('id', uuid4().hex),
            content=data.get('content', ""),
            timestamp=datetime.fromisoformat(data['timestamp']) if 'timestamp' in data else datetime.now(timezone.utc),
            modality=data.get('modality'),
            type=data.get('type'),
            status=data.get('status'),
            tags=data.get('tags', []),
            embedding=data.get('embedding'),
            location=data.get('location'),
            source_role=data.get('source_role'),
            source_name=data.get('source_name'),
            source_id=data.get('source_id', uuid4().hex),
            edges=[MemoryLink.from_dict(edge) for edge in data.get('edges', [])],
            metadata=data.get('metadata', {})
        )
    except Exception as e:
        # logger.error(f"Error deserializing MemoryUnit from dict: {e}")
        raise e
to_dict()

Converts the MemoryUnit instance to a dictionary format for serialization. Each field is handled explicitly to ensure proper serialization.

Returns:

    Dict[str, Any]: A dictionary representation of the MemoryUnit.

Source code in src/aeiva/cognition/memory/memory_unit.py
def to_dict(self) -> Dict[str, Any]:
    """
    Converts the MemoryUnit instance to a dictionary format for serialization.
    Each field is handled explicitly to ensure proper serialization.

    Returns:
        Dict[str, Any]: A dictionary representation of the MemoryUnit.
    """
    return {
        "id": self.id,
        "content": self.content,
        "timestamp": self.timestamp.isoformat(),  # Convert datetime to string
        "modality": self.modality,
        "type": self.type,
        "status": self.status,
        "tags": self.tags,
        "embedding": self.embedding,
        "location": self.location,
        "source_role": self.source_role,
        "source_name": self.source_name,
        "source_id": self.source_id,
        "edges": [edge.to_dict() for edge in self.edges],  # Serialize each MemoryLink
        "metadata": self.metadata
    }
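
A short round-trip sketch for the serialization helpers above (field values are illustrative):

unit = MemoryUnit(
    content="Met with the team about the roadmap.",
    modality="text",
    type="dialogue",
    tags=["meeting"],
)

data = unit.to_dict()              # JSON-serializable dictionary
restored = MemoryUnit.from_dict(data)
assert restored.id == unit.id and restored.content == unit.content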

memory_utils

derive_content(derivation_type, data)

You are a creative assistant capable of deriving new content based on specified types. Your task is to derive a {derivation_type} from the provided combined content.

Source code in src/aeiva/cognition/memory/memory_utils.py
@simple(model='gpt-4', temperature=0.7)
def derive_content(derivation_type: str, data: str) -> str:
    """
    You are a creative assistant capable of deriving new content based on specified types.
    Your task is to derive a {derivation_type} from the provided combined content.
    """
    result = f"Derive a {derivation_type} from the following content:\n{data}"
    return result
extract_entities_relationships(data)

You are an intelligent assistant skilled in natural language processing. Your task is to extract entities and the relationships between them from the provided content.

Source code in src/aeiva/cognition/memory/memory_utils.py
@simple(model='gpt-4', temperature=0.7)
def extract_entities_relationships(data: Any) -> str:
    """
    You are an intelligent assistant skilled in natural language processing.
    Your task is to extract entities and the relationships between them from the provided content.
    """
    result = f"Extract entities and relationships from the following content:\n{data}"
    return result
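
Both helpers are wrapped with @simple(model='gpt-4', temperature=0.7), which appears to dispatch the constructed prompt to the named model; assuming that behavior, a call looks like the sketch below (inputs are illustrative):

# Hypothetical usage; the decorator is assumed to return the model's completion.
summary = derive_content("summary", "Long combined memory content ...")
entities = extract_entities_relationships("Alice works with Bob at Acme.")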

storage_config

StorageConfig dataclass

Bases: BaseConfig

Configuration class for the Memory storage.

Attributes:

    vector_db_config (DatabaseConfig): Configuration for the vector database.
    graph_db_config (Optional[DatabaseConfig]): Configuration for the graph database.
    relational_db_config (Optional[DatabaseConfig]): Configuration for the relational database.

Source code in src/aeiva/cognition/memory/storage_config.py
@dataclass
class StorageConfig(BaseConfig):
    """
    Configuration class for the Memory storage.

    Attributes:
        vector_db_config (DatabaseConfig): Configuration for the vector database.
        graph_db_config (Optional[DatabaseConfig]): Configuration for the graph database.
        relational_db_config (Optional[DatabaseConfig]): Configuration for the relational database.
    """
    vector_db_provider: str = field(
        metadata={"help": "Vector database provider name."}
    )
    vector_db_config: BaseConfig = field(
        metadata={"help": "Configuration for the vector database."}
    )
    graph_db_provider: Optional[str] = field(
        default=None,
        metadata={"help": "Graph database provider name."}
    )
    graph_db_config: Optional[BaseConfig] = field(
        default=None,
        metadata={"help": "Configuration for the graph database."}
    )
    relational_db_provider: Optional[str] = field(
        default=None,
        metadata={"help": "Relational database provider name."}
    )
    relational_db_config: Optional[BaseConfig] = field(
        default=None,
        metadata={"help": "Configuration for the relational database."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Perform any necessary validation
        if not self.vector_db_config:
            raise ValueError("Vector database configuration must be provided.")
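
A construction sketch, assuming vector_db_config, graph_db_config, and relational_db_config are DatabaseConfigFactory products as in MemoryStorage.setup():

storage_config = StorageConfig(
    vector_db_provider='milvus',
    vector_db_config=vector_db_config,          # required; __post_init__ rejects a missing value
    graph_db_provider='neo4j',
    graph_db_config=graph_db_config,            # optional
    relational_db_provider='sqlite',
    relational_db_config=relational_db_config,  # optional
)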

observation

Observation

Represents a processed input from the PerceptionSystem.

Source code in src/aeiva/cognition/observation.py
class Observation:
    """
    Represents a processed input from the PerceptionSystem.
    """
    def __init__(self, data: Any, modality: str = 'text', timestamp: Optional[datetime] = None, metadata: Optional[Dict[str, Any]] = None):
        self.data = data  # The processed data (e.g., text)
        self.modality = modality
        self.timestamp = timestamp or datetime.now()
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        return {
            'data': self.data,
            'modality': self.modality,
            'timestamp': self.timestamp.isoformat(),
            'metadata': self.metadata
        }

thought

Thought

Represents the output from the Brain after processing an Observation.

Source code in src/aeiva/cognition/thought.py
class Thought:
    """
    Represents the output from the Brain after processing an Observation.
    """
    def __init__(self, content: Any, modality: str = 'text', timestamp: Optional[datetime] = None, metadata: Optional[Dict[str, Any]] = None):
        self.content = content  # The thought content (e.g., text)
        self.modality = modality
        self.timestamp = timestamp or datetime.now()
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        return {
            'content': self.content,
            'modality': self.modality,
            'timestamp': self.timestamp.isoformat(),
            'metadata': self.metadata
        }

world_model

world_model

WorldModel

Bases: ABC

Abstract base class representing the World Model system of an agent.

The World Model maintains an internal representation of the environment, enabling the agent to understand, predict, and interact with its surroundings effectively.

Attributes:

    config (Any): Configuration settings for the World Model system.
    state (Any): The internal state of the World Model system.

Source code in src/aeiva/cognition/world_model/world_model.py
class WorldModel(ABC):
    """
    Abstract base class representing the World Model system of an agent.

    The World Model maintains an internal representation of the environment, enabling the agent
    to understand, predict, and interact with its surroundings effectively.

    Attributes:
        config (Any): Configuration settings for the World Model system.
        state (Any): The internal state of the World Model system.
    """

    def __init__(self, config: Any):
        """
        Initialize the World Model system with the provided configuration.

        Args:
            config (Any): Configuration settings for the World Model system.
        """
        self.config = config
        self.state = self.init_state()

    @abstractmethod
    def init_state(self) -> Any:
        """
        Initialize the internal state of the World Model system.

        This method should set up the initial state required for the World Model system's operations.

        Returns:
            Any: The initial state of the World Model system.
        """
        pass

    @abstractmethod
    def setup(self) -> None:
        """
        Asynchronously set up the World Model system's components.

        This method should initialize any necessary components or resources based on the provided configuration.

        Raises:
            ConfigurationError: If the configuration is invalid or incomplete.
        """
        pass

    @abstractmethod
    async def update(self, observation: Any) -> None:
        """
        Asynchronously update the world model based on new observations.

        Args:
            observation (Any): The new observation to incorporate into the world model.

        Raises:
            UpdateError: If updating the world model fails.
        """
        pass

    @abstractmethod
    async def query(self, query: Any) -> Any:
        """
        Asynchronously query the world model for specific information.

        Args:
            query (Any): The query or criteria to retrieve specific information from the world model.

        Returns:
            Any: The information retrieved from the world model.

        Raises:
            QueryError: If the query process fails.
        """
        pass

    def get_current_state(self) -> Any:
        """
        Retrieve the current internal state of the World Model system.

        Returns:
            Any: The current internal state.
        """
        return self.state

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during world model operations.

        This method can be overridden to implement custom error handling logic.

        Args:
            error (Exception): The exception that was raised.
        """
        # Default error handling: log the error
        print(f"WorldModel system encountered an error: {error}")
__init__(config)

Initialize the World Model system with the provided configuration.

Parameters:

    config (Any): Configuration settings for the World Model system. Required.
Source code in src/aeiva/cognition/world_model/world_model.py
def __init__(self, config: Any):
    """
    Initialize the World Model system with the provided configuration.

    Args:
        config (Any): Configuration settings for the World Model system.
    """
    self.config = config
    self.state = self.init_state()
get_current_state()

Retrieve the current internal state of the World Model system.

Returns:

    Any: The current internal state.

Source code in src/aeiva/cognition/world_model/world_model.py
def get_current_state(self) -> Any:
    """
    Retrieve the current internal state of the World Model system.

    Returns:
        Any: The current internal state.
    """
    return self.state
handle_error(error)

Handle errors that occur during world model operations.

This method can be overridden to implement custom error handling logic.

Parameters:

    error (Exception): The exception that was raised. Required.
Source code in src/aeiva/cognition/world_model/world_model.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during world model operations.

    This method can be overridden to implement custom error handling logic.

    Args:
        error (Exception): The exception that was raised.
    """
    # Default error handling: log the error
    print(f"WorldModel system encountered an error: {error}")
init_state() abstractmethod

Initialize the internal state of the World Model system.

This method should set up the initial state required for the World Model system's operations.

Returns:

    Any: The initial state of the World Model system.

Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod
def init_state(self) -> Any:
    """
    Initialize the internal state of the World Model system.

    This method should set up the initial state required for the World Model system's operations.

    Returns:
        Any: The initial state of the World Model system.
    """
    pass
query(query) abstractmethod async

Asynchronously query the world model for specific information.

Parameters:

    query (Any): The query or criteria to retrieve specific information from the world model. Required.

Returns:

    Any: The information retrieved from the world model.

Raises:

    QueryError: If the query process fails.

Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod
async def query(self, query: Any) -> Any:
    """
    Asynchronously query the world model for specific information.

    Args:
        query (Any): The query or criteria to retrieve specific information from the world model.

    Returns:
        Any: The information retrieved from the world model.

    Raises:
        QueryError: If the query process fails.
    """
    pass
setup() abstractmethod

Asynchronously set up the World Model system's components.

This method should initialize any necessary components or resources based on the provided configuration.

Raises:

    ConfigurationError: If the configuration is invalid or incomplete.

Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod
def setup(self) -> None:
    """
    Asynchronously set up the World Model system's components.

    This method should initialize any necessary components or resources based on the provided configuration.

    Raises:
        ConfigurationError: If the configuration is invalid or incomplete.
    """
    pass
update(observation) abstractmethod async

Asynchronously update the world model based on new observations.

Parameters:

    observation (Any): The new observation to incorporate into the world model. Required.

Raises:

    UpdateError: If updating the world model fails.

Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod
async def update(self, observation: Any) -> None:
    """
    Asynchronously update the world model based on new observations.

    Args:
        observation (Any): The new observation to incorporate into the world model.

    Raises:
        UpdateError: If updating the world model fails.
    """
    pass
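
Since WorldModel is abstract, a subclass must implement init_state, setup, update, and query. A minimal sketch (the dict-based state and substring matching are illustrative choices, not part of the interface):

class SimpleWorldModel(WorldModel):
    def init_state(self) -> dict:
        return {"observations": []}

    def setup(self) -> None:
        pass  # no external resources needed in this sketch

    async def update(self, observation) -> None:
        self.state["observations"].append(observation)

    async def query(self, query) -> list:
        # Naive substring match over stored observations.
        return [o for o in self.state["observations"] if str(query) in str(o)]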

command

aeiva_chat_gradio

You can run the command as follows (specify your own config file path):

aeiva-chat-gradio --config configs/agent_config.yaml
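
To enable verbose logging, add the --verbose (or -v) flag defined by the command below:

aeiva-chat-gradio --config configs/agent_config.yaml --verbose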

run(config, verbose)

Starts the Aeiva chat Gradio interface with the provided configuration.

Source code in src/aeiva/command/aeiva_chat_gradio.py
@click.command(name="aeiva-chat-gradio")
@click.option('--config', '-c', default=str(DEFAULT_CONFIG_PATH),
              help='Path to the configuration file (YAML or JSON).',
              type=click.Path(exists=True, dir_okay=False))
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging.')
def run(config, verbose):
    """
    Starts the Aeiva chat Gradio interface with the provided configuration.
    """
    # Setup logging
    logger = setup_logging(DEFAULT_LOG_PATH, verbose)

    # Load environment variables (API keys, etc.)
    load_dotenv()

    logger.info(f"Loading configuration from {config}")
    config_dict = from_json_or_yaml(config)

    # Initialize the Agent
    try:
        agent = Agent(config_dict)
        agent.setup()
        logger.info("Agent initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize Agent: {e}")
        click.echo(f"Error: Failed to initialize Agent: {e}")
        sys.exit(1)

    # Function to run the Agent's run method in a separate thread
    def run_agent(agent_instance):
        try:
            asyncio.run(agent_instance.run())
        except Exception as e:
            logger.error(f"Error running Agent: {e}")

    # Start the Agent in a separate daemon thread
    agent_thread = threading.Thread(target=run_agent, args=(agent,), daemon=True)
    agent_thread.start()
    logger.info("Agent run thread started.")

    # Initialize a thread-safe queue to receive responses from the Agent
    response_queue = queue.Queue()

    # Define a handler for 'response.gradio' events
    def handle_response_gradio(event: Event):
        response = event.payload
        response_queue.put_nowait(response)  # Put response into the thread-safe queue
        logger.info(f"Received 'response.gradio' event: {response}")

    # Register the handler with the Agent's EventBus
    agent.event_bus.on('response.gradio')(handle_response_gradio)
    logger.info("Registered handler for 'response.gradio' events.")

    # Validate and start Neo4j
    neo4j_home = os.getenv('NEO4J_HOME')
    if not neo4j_home:
        logger.error("NEO4J_HOME environment variable is not set.")
        click.echo("Error: NEO4J_HOME environment variable is not set.")
        sys.exit(1)

    validate_neo4j_home(logger, neo4j_home)
    neo4j_process = start_neo4j(logger, neo4j_home)

    # Register signal handlers to ensure Neo4j stops gracefully
    for sig in [signal.SIGINT, signal.SIGTERM]:
        signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process))

    # Define handlers for multimodal inputs

    def handle_image_upload(image: Image.Image):
        if image is not None:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            image_path = f"uploads/uploaded_image_{timestamp}.jpg"
            try:
                image.save(image_path)
                logger.info(f"Image uploaded and saved to {image_path}")
                return "User uploaded an image."
            except Exception as e:
                logger.error(f"Error saving uploaded image: {e}")
                return "Failed to upload image."
        return ""

    def handle_video_upload(video):
        if video is not None:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            video_path = f"uploads/uploaded_video_{timestamp}.mp4"
            try:
                with open(video_path, "wb") as f:
                    f.write(video.read())
                logger.info(f"Video uploaded and saved to {video_path}")
                return "User uploaded a video."
            except Exception as e:
                logger.error(f"Error saving uploaded video: {e}")
                return "Failed to upload video."
        return ""

    def handle_audio_upload(audio):
        if audio is not None:
            try:
                sample_rate, audio_data = audio
                # Normalize audio_data to float32 in the range -1.0 to 1.0
                audio_data_normalized = audio_data.astype(np.float32) / np.abs(audio_data).max()
                timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
                audio_path = f"uploads/uploaded_audio_{timestamp}.wav"
                sf.write(audio_path, audio_data_normalized, sample_rate, subtype='PCM_16')
                logger.info(f"Audio uploaded and saved to {audio_path}")
                return "User uploaded an audio file."
            except Exception as e:
                logger.error(f"Error saving uploaded audio: {e}")
                return "Failed to upload audio."
        return ""

    def handle_upload(file):
        """
        Handles file uploads and delegates to specific handlers based on file type.

        Args:
            file: Uploaded file object.

        Returns:
            str: Message indicating the upload status.
        """
        if file is None:
            return ""
        if file.type.startswith("image"):
            return handle_image_upload(file)
        elif file.type.startswith("video"):
            return handle_video_upload(file)
        elif file.type.startswith("audio"):
            return handle_audio_upload(file)
        else:
            logger.warning(f"Unsupported file type uploaded: {file.type}")
            return "Unsupported file type uploaded."

    def clear_media():
        """
        Clears the uploaded media paths.
        """
        # Implement any necessary logic to clear media paths or data
        logger.info("Cleared uploaded media paths.")
        return ""

    async def bot(user_input, history):
        """
        Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.
        """
        if agent is None:
            logger.error("Agent is not initialized.")
            history.append({"role": "assistant", "content": "Agent is not initialized."})
            yield history, ''
            return

        try:
            # Append user's message to history
            history.append({"role": "user", "content": user_input})
            # Append an empty assistant response
            history.append({"role": "assistant", "content": ""})
            yield history, ''  # Display the user's message
            logger.info(f"User input appended to history: {user_input}")

            stream = config_dict["llm_gateway_config"]["llm_stream"]
            use_async = config_dict["llm_gateway_config"]["llm_use_async"]

            # Emit the 'perception.gradio' event with stream=True
            emit_future = asyncio.run_coroutine_threadsafe(
                agent.event_bus.emit('perception.gradio', payload=user_input),
                agent.event_bus.loop
            )
            emit_future.result()  # Ensure the event is emitted
            logger.info(f"Emitted 'perception.gradio' event with payload: {user_input} | Stream: {stream}")

            assistant_message = ''
            if stream:
                while True:
                    try:
                        # Non-blocking response retrieval from the thread-safe queue with timeout
                        response = await asyncio.wait_for(
                            asyncio.to_thread(response_queue.get, True, 30),
                            timeout=30
                        )
                        logger.info(f"Retrieved response from queue: {response}")
                        if response == "<END_OF_RESPONSE>":
                            logger.info("Received end of response signal.")
                            break
                        assistant_message += response
                        # Create a new history list to ensure Gradio detects the update
                        new_history = history.copy()
                        new_history[-1]["content"] = assistant_message
                        logger.info(f"Yielding updated history: {new_history}")
                        yield new_history, ''
                    except asyncio.TimeoutError:
                        logger.warning("Timeout: No response received from Agent.")
                        # Create a new history list to ensure Gradio detects the update
                        new_history = history.copy()
                        new_history[-1]["content"] = "I'm sorry, I didn't receive a response in time."
                        yield new_history, ''
                        break
            else:
                try:
                    # Non-blocking response retrieval from the thread-safe queue with timeout
                    response = await asyncio.wait_for(
                        asyncio.to_thread(response_queue.get, True, 30),
                        timeout=30
                    )
                    logger.info(f"Retrieved response from queue: {response}")
                    assistant_message += response
                    # Create a new history list to ensure Gradio detects the update
                    new_history = history.copy()
                    new_history[-1]["content"] = assistant_message
                    logger.info(f"Yielding updated history: {new_history}")
                    yield new_history, ''
                except asyncio.TimeoutError:
                    logger.warning("Timeout: No response received from Agent.")
                    # Create a new history list to ensure Gradio detects the update
                    new_history = history.copy()
                    new_history[-1]["content"] = "I'm sorry, I didn't receive a response in time."
                    yield new_history, ''

        except Exception as e:
            logger.error(f"Unexpected Error in bot function: {e}")
            # Create a new history list to ensure Gradio detects the update
            new_history = history.copy()
            new_history[-1]["content"] = "An unexpected error occurred."
            yield new_history, ''

    def launch_gradio_interface():
        """
        Main gradio interface.
        """
        with gr.Blocks(title="Multimodal LLM Chatbot with Tools") as demo:
            # Header Section
            gr.Markdown("""
            <h1 align="center">
                <a href="https://github.com/chatsci/Aeiva">
                    <img src="https://i.ibb.co/P4zQHDk/aeiva-1024.png"
                    alt="Aeiva" border="0" style="margin: 0 auto; height: 200px;" />
                </a>
            </h1>

            <h2 align="center">
                AEIVA: An Evolving Intelligent Virtual Assistant
            </h2>

            <h5 align="center">
                If you like our project, please give us a star ✨ on Github for the latest update.
            </h5>

            <div align="center">
                <div style="display:flex; gap: 0.25rem;" align="center">
                    <a href='https://github.com/chatsci/Aeiva'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
                    <a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
                    <a href='https://github.com/chatsci/Aeiva/stargazers'><img src='https://img.shields.io/github/stars/chatsci/Aeiva.svg?style=social'></a>
                </div>
            </div>
            """)

            # Main Layout: Two Columns
            with gr.Row():
                # Left Column: Parameter Settings and Multimodal Inputs
                with gr.Column(scale=1, min_width=700):
                    # Parameter Settings Tab
                    with gr.Tab(label="Parameter Setting"):
                        gr.Markdown("# Parameters")
                        top_p = gr.Slider(
                            minimum=0,
                            maximum=1.0,
                            value=0.95,
                            step=0.05,
                            interactive=True,
                            label="Top-p"
                        )
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            interactive=True,
                            label="Temperature"
                        )
                        max_length_tokens = gr.Slider(
                            minimum=0,
                            maximum=512,
                            value=512,
                            step=8,
                            interactive=True,
                            label="Max Generation Tokens"
                        )
                        max_context_length_tokens = gr.Slider(
                            minimum=0,
                            maximum=4096,
                            value=2048,
                            step=128,
                            interactive=True,
                            label="Max History Tokens"
                        )

                    # Multimodal Inputs Section
                    with gr.Row():
                        imagebox = gr.Image(type="pil", label="Upload Image")
                        videobox = gr.File(label="Upload Video", file_types=["video"])
                        audiobox = gr.Audio(label="Upload Audio", type="numpy")

                    with gr.Row():
                        record_videobox = gr.Video(label="Record Video")
                        record_audiobox = gr.Audio(label="Record Audio")

                    # Clear Media Button
                    with gr.Row():
                        clear_media_btn = gr.Button("🧹 Clear Media", variant="secondary")

                # Right Column: Chat Interface and Action Buttons
                with gr.Column(scale=1, min_width=700):
                    # Chatbot Component
                    chatbot = gr.Chatbot(
                        [],
                        type="messages",  # Specify type as 'messages'
                        elem_id="chatbot",
                        height=730
                    )

                    # Input Textbox and Upload Button
                    with gr.Row():
                        with gr.Column(scale=4, min_width=300):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Enter text and press enter, or upload an image/video/audio",
                                lines=1,
                                elem_classes=["input-textbox"]  # Assign a CSS class for styling
                            )
                        with gr.Column(scale=1, min_width=100):
                            btn = gr.UploadButton("📁", file_types=["image", "video", "audio"], elem_classes=["upload-button"])
                            # Changed the button label to an icon for a more compact look

                    # Action Buttons Placed Below the Input Box
                    with gr.Row():
                        upvote_btn = gr.Button("👍 Upvote", interactive=True)
                        downvote_btn = gr.Button("👎 Downvote", interactive=True)
                        flag_btn = gr.Button("⚠️ Flag", interactive=True)
                        regenerate_btn = gr.Button("🔄 Regenerate", interactive=True)
                        clear_history_btn = gr.Button("🗑️ Clear History", interactive=True)
                        new_conv_btn = gr.Button("🧹 New Conversation", interactive=True)
                        del_last_turn_btn = gr.Button("🗑️ Remove Last Turn", interactive=True)

            # Define interactions

            # Text input submission with streaming
            txt.submit(
                bot,
                inputs=[txt, chatbot],
                outputs=[chatbot, txt],
                queue=True,    # Enable queue for better performance
                # stream=True    # Enable streaming (already handled in the bot function)
            )
            # Removed the .then callback to prevent layout shifts

            # File upload (image/video/audio)
            btn.upload(
                handle_upload,
                inputs=btn,
                outputs=txt,  # Set message in textbox to trigger bot
                queue=True
            )

            # Image upload
            imagebox.upload(
                handle_image_upload,
                inputs=imagebox,
                outputs=txt,  # Set message in textbox to trigger bot
                queue=True
            )

            # Video upload
            videobox.upload(
                handle_video_upload,
                inputs=videobox,
                outputs=txt,  # Set message in textbox to trigger bot
                queue=True
            )

            # Audio upload
            audiobox.upload(
                handle_audio_upload,
                inputs=audiobox,
                outputs=txt,  # Set message in textbox to trigger bot
                queue=True
            )

            # Record Video
            record_videobox.change(
                handle_video_upload,
                inputs=record_videobox,
                outputs=txt,  # Set message in textbox to trigger bot
                queue=True
            )

            # Record Audio
            record_audiobox.change(
                handle_audio_upload,
                inputs=record_audiobox,
                outputs=txt,  # Set message in textbox to trigger bot
                queue=True
            )

            # Clear Media Button
            clear_media_btn.click(
                clear_media,
                inputs=None,
                outputs=None,
                queue=False
            )

            # Action Buttons Functionality

            # Clear History
            clear_history_btn.click(
                lambda: ([], ""),
                inputs=None,
                outputs=[chatbot, txt],
                queue=False
            )

            # New Conversation
            new_conv_btn.click(
                lambda: ([], ""),
                inputs=None,
                outputs=[chatbot, txt],
                queue=False
            )

            # Remove Last Turn (Removes the last user and assistant messages)
            del_last_turn_btn.click(
                lambda history: history[:-2] if len(history) >= 2 else history,
                inputs=chatbot,
                outputs=chatbot,
                queue=False
            )

        # Launch the Gradio interface
        demo.launch(share=True)

    # Launch aeiva chat gradio
    launch_gradio_interface()

aeiva_chat_terminal

We can run the command as shown below (specify your own config file path):

aeiva-chat-terminal --config configs/agent_config.yaml

run(config, verbose)

Starts the Aeiva chat terminal with the provided configuration.

Source code in src/aeiva/command/aeiva_chat_terminal.py
@click.command()
@click.option('--config', '-c', default=str(DEFAULT_CONFIG_PATH),
              help='Path to the configuration file (YAML or JSON).',
              type=click.Path(exists=True, dir_okay=False))
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging.')
def run(config, verbose):
    """
    Starts the Aeiva chat terminal with the provided configuration.
    """
    # Setup logging
    logger = setup_logging(DEFAULT_LOG_PATH, verbose)

    click.echo(f"Loading configuration from {config}")
    config_path = Path(config)

    # Parse the configuration file with error handling
    try:
        config_data = from_json_or_yaml(config_path)
    except Exception as e:
        logger.error(f"Failed to parse configuration file: {e}")
        click.echo(f"Error: Failed to parse configuration file: {e}")
        sys.exit(1)

    # Retrieve NEO4J_HOME from environment variables
    neo4j_home = os.getenv('NEO4J_HOME')
    if not neo4j_home:
        logger.error("NEO4J_HOME is not set in the environment.")
        click.echo("Error: NEO4J_HOME is not set in the environment. Please set it in your shell configuration (e.g., .bashrc or .zshrc).")
        sys.exit(1)

    # Validate NEO4J_HOME path
    validate_neo4j_home(logger, neo4j_home)

    # Start Neo4j
    neo4j_process = start_neo4j(logger, neo4j_home)

    # Register signal handlers to ensure Neo4j stops gracefully
    signal.signal(signal.SIGINT, lambda s, f: handle_exit(s, f, logger, neo4j_process))
    signal.signal(signal.SIGTERM, lambda s, f: handle_exit(s, f, logger, neo4j_process))

    # Start the Agent
    try:
        agent = Agent(config_data)
        agent.setup()
        asyncio.run(agent.run())
    except KeyboardInterrupt:
        logger.info("Agent execution interrupted by user.")
        click.echo("\nAgent execution interrupted by user.")
    except Exception as e:
        logger.error(f"An error occurred during agent execution: {e}")
        click.echo(f"An error occurred during agent execution: {e}")
    finally:
        # # Perform any necessary cleanup
        # try:
        #     agent.cognition_components['memory'].delete_all()
        #     logger.info("All memory units deleted during cleanup.")
        # except NotImplementedError as nie:
        #     logger.warning(f"Delete All feature not implemented: {nie}")
        # except Exception as e:
        #     logger.error(f"Error during cleanup: {e}")
        #     click.echo("Failed to delete all memory units.")

        # Stop Neo4j
        stop_neo4j(logger, neo4j_process)
        logger.info("Cleanup completed.")

aeiva_server

run(config, host, port, verbose)

Starts the Aeiva Agent Server using FastAPI.

Source code in src/aeiva/command/aeiva_server.py
@click.command(name="aeiva-server")
@click.option(
    '--config', '-c',
    default=None,
    help='Path to the configuration file (YAML or JSON).',
    type=click.Path(exists=True, dir_okay=False)
)
@click.option(
    '--host', '-H',
    default="0.0.0.0",
    help='Host address to run the server on.',
    show_default=True
)
@click.option(
    '--port', '-p',
    default=8000,
    help='Port number to run the server on.',
    show_default=True
)
@click.option(
    '--verbose', '-v',
    is_flag=True,
    help='Enable verbose logging.'
)
def run(config, host, port, verbose):
    """
    Starts the Aeiva Agent Server using FastAPI.
    """
    # Setup logging
    logger = setup_logging(get_log_dir() / 'aeiva-server.log', verbose)

    # Load configuration
    if config is None:
        PACKAGE_ROOT = get_package_root()
        config_path = PACKAGE_ROOT / 'configs' / 'agent_config.yaml'
    else:
        config_path = Path(config)

    logger.info(f"Loading configuration from {config_path}")
    config_dict = from_json_or_yaml(config_path)

    # Validate and start Neo4j
    neo4j_home = os.getenv('NEO4J_HOME')
    if not neo4j_home:
        logger.error("NEO4J_HOME environment variable is not set.")
        click.echo("Error: NEO4J_HOME environment variable is not set.")
        sys.exit(1)

    validate_neo4j_home(logger, neo4j_home)
    neo4j_process = start_neo4j(logger, neo4j_home)

    # Initialize the Agent
    try:
        agent = Agent(config_dict)
        agent.setup()
        logger.info("Agent initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize Agent: {e}")
        click.echo(f"Error: Failed to initialize Agent: {e}")
        stop_neo4j(logger, neo4j_process)
        sys.exit(1)

    # Define the FastAPI app with lifespan
    @asynccontextmanager
    async def lifespan(app: FastAPI):
        app.state.agent = agent
        logger.info("Agent has been initialized and is ready to receive messages.")
        try:
            yield
        finally:
            logger.info("Shutting down the agent server.")
            # If the Agent class has a shutdown method, call it here
            if hasattr(app.state.agent, 'shutdown'):
                await app.state.agent.shutdown()
            stop_neo4j(logger, neo4j_process)
            logger.info("Agent server shut down gracefully.")

    app = FastAPI(lifespan=lifespan)

    # Enable CORS for all origins (for development purposes)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Adjust in production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Define the endpoint
    @app.post("/process_text", response_model=MessageResponse)
    async def process_text(request: MessageRequest):
        if not request.message:
            raise HTTPException(status_code=400, detail="No message provided")

        logger.info(f"Received message: {request.message}")

        # Process the message using the agent
        try:
            response_text = await app.state.agent.process_input(request.message)
            logger.info(f"Agent response: {response_text}")
            return MessageResponse(response=response_text)
        except Exception as e:
            logger.error(f"Error processing input: {e}")
            raise HTTPException(status_code=500, detail="Internal Server Error")

    # Register signal handlers for graceful shutdown using handle_exit
    for sig in [signal.SIGINT, signal.SIGTERM]:
        signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process))

    # Run the FastAPI app using Uvicorn
    try:
        logger.info(f"Starting server at http://{host}:{port}")
        uvicorn.run(app, host=host, port=port)
    except Exception as e:
        logger.error(f"Server encountered an error: {e}")
        handle_exit(None, None, logger, neo4j_process)  # Ensure cleanup on exception
        sys.exit(1)
    finally:
        logger.info("Server has been stopped.")

command_utils

Here we put utility functions related to databases, logging, and so on that support the execution of the different aeiva commands.

get_log_dir()

Determines a suitable path for the log file. Logs are stored in the user's home directory under '.aeiva/logs/'.

Source code in src/aeiva/command/command_utils.py
def get_log_dir():
    """
    Determines a suitable path for the log file.
    Logs are stored in the user's home directory under '.aeiva/logs/'.
    """
    home_dir = Path.home()
    log_dir = home_dir / '.aeiva' / 'logs'  # Log saved to `~/.aeiva/logs/`
    log_dir.mkdir(parents=True, exist_ok=True)  # Ensure the log directory exists
    return log_dir

get_package_root()

Determines the root path of the 'aeiva' package.

Source code in src/aeiva/command/command_utils.py
def get_package_root():
    """
    Determines the root path of the 'aeiva' package.
    """
    aeiva_path = Path(importlib_resources.files("aeiva"))
    package_root = aeiva_path.parents[1]
    return package_root.resolve()

handle_exit(signum, frame, logger, neo4j_process)

Handles termination signals to ensure Neo4j is stopped gracefully.

Source code in src/aeiva/command/command_utils.py
def handle_exit(signum, frame, logger, neo4j_process):
    """
    Handles termination signals to ensure Neo4j is stopped gracefully.
    """
    logger.info(f"Received signal {signum}. Shutting down Neo4j.")
    click.echo(f"\nReceived signal {signum}. Shutting down Neo4j.")
    stop_neo4j(logger, neo4j_process)
    sys.exit(0)

setup_logging(log_file, verbose=False)

Sets up logging to both file and console.

Source code in src/aeiva/command/command_utils.py
def setup_logging(log_file, verbose=False):
    """
    Sets up logging to both file and console.
    """
    logger = get_logger(__name__, level="DEBUG" if verbose else "INFO")

    # Create a file handler
    file_handler = logging.FileHandler(log_file, mode='a')
    file_handler.setLevel(logging.DEBUG if verbose else logging.INFO)

    # Create a console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)

    # Create a logging format
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    # Add handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
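
A minimal usage sketch (the log file name is illustrative):

from aeiva.command.command_utils import get_log_dir, setup_logging

# With verbose=True, DEBUG-level records go to both the file and the console.
logger = setup_logging(get_log_dir() / 'example.log', verbose=True)
logger.debug("Diagnostic message written to file and console.")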

start_neo4j(logger, neo4j_home)

Starts the Neo4j database as a subprocess.

Source code in src/aeiva/command/command_utils.py
def start_neo4j(logger, neo4j_home):
    """
    Starts the Neo4j database as a subprocess.
    """
    neo4j_command = [os.path.join(neo4j_home, 'bin', 'neo4j'), 'console']
    try:
        neo4j_process = subprocess.Popen(
            neo4j_command,
            stdout=subprocess.DEVNULL,  # Suppress stdout
            stderr=subprocess.DEVNULL,  # Suppress stderr
            stdin=subprocess.DEVNULL,   # Prevent Neo4j from waiting for input
            preexec_fn=os.setsid       # Start the process in a new session
        )
        logger.info("Neo4j database started successfully.")
        click.echo("Neo4j database started successfully.")
        return neo4j_process
    except FileNotFoundError:
        logger.error(f"Neo4j executable not found in {neo4j_command}.")
        click.echo(f"Error: Neo4j executable not found in {neo4j_command}.")
        sys.exit(1)
    except Exception as e:
        logger.error(f"Failed to start Neo4j: {e}")
        click.echo(f"Error: Failed to start Neo4j: {e}")
        sys.exit(1)

stop_neo4j(logger, neo4j_process)

Stops the Neo4j database subprocess gracefully.

Source code in src/aeiva/command/command_utils.py
def stop_neo4j(logger, neo4j_process):
    """
    Stops the Neo4j database subprocess gracefully.
    """
    try:
        # Check if the process is still running
        if neo4j_process.poll() is None:
            os.killpg(os.getpgid(neo4j_process.pid), signal.SIGINT)  # Send SIGINT for graceful shutdown
            logger.info("Sent SIGINT to Neo4j subprocess.")
            click.echo("Shutting down Neo4j...")
            neo4j_process.wait(timeout=15)  # Increased timeout to 15 seconds
            logger.info("Neo4j database stopped successfully.")
            click.echo("Neo4j database stopped successfully.")
        else:
            logger.warning("Neo4j subprocess is already terminated.")
            click.echo("Warning: Neo4j subprocess is already terminated.")
    except subprocess.TimeoutExpired:
        logger.error("Neo4j did not terminate within the timeout period.")
        click.echo("Error: Neo4j did not terminate within the timeout period.")
        # Optionally, force kill
        try:
            os.killpg(os.getpgid(neo4j_process.pid), signal.SIGKILL)
            neo4j_process.wait(timeout=5)
            logger.info("Neo4j database forcefully terminated.")
            click.echo("Neo4j database forcefully terminated.")
        except Exception as e:
            logger.error(f"Failed to forcefully terminate Neo4j: {e}")
            click.echo(f"Error: Failed to forcefully terminate Neo4j: {e}")
    except ProcessLookupError:
        logger.warning("Neo4j subprocess does not exist.")
        click.echo("Warning: Neo4j subprocess does not exist. It may have already terminated.")
    except Exception as e:
        logger.error(f"Error stopping Neo4j: {e}")
        click.echo(f"Error: Failed to stop Neo4j: {e}")

validate_neo4j_home(logger, neo4j_home)

Validates that the NEO4J_HOME path exists and contains the Neo4j executable.

Source code in src/aeiva/command/command_utils.py
def validate_neo4j_home(logger, neo4j_home):
    """
    Validates that the NEO4J_HOME path exists and contains the Neo4j executable.
    """
    if not os.path.isdir(neo4j_home):
        logger.error(f"NEO4J_HOME path does not exist or is not a directory: {neo4j_home}")
        click.echo(f"Error: NEO4J_HOME path does not exist or is not a directory: {neo4j_home}")
        sys.exit(1)

    neo4j_executable = os.path.join(neo4j_home, 'bin', 'neo4j')
    if not os.path.isfile(neo4j_executable) or not os.access(neo4j_executable, os.X_OK):
        logger.error(f"Neo4j executable not found or not executable at: {neo4j_executable}")
        click.echo(f"Error: Neo4j executable not found or not executable at: {neo4j_executable}")
        sys.exit(1)

maid_chat

run(config, host, port, verbose)

Starts the Aeiva Agent Server and launches the Unity application.

Source code in src/aeiva/command/maid_chat.py
@click.command(name="maid-chat")
@click.option(
    '--config', '-c',
    default=None,
    help='Path to the configuration file (YAML or JSON).',
    type=click.Path(exists=True, dir_okay=False)
)
@click.option(
    '--host', '-H',
    default="0.0.0.0",
    help='Host address to run the server on.',
    show_default=True
)
@click.option(
    '--port', '-p',
    default=8000,
    help='Port number to run the server on.',
    show_default=True
)
@click.option(
    '--verbose', '-v',
    is_flag=True,
    help='Enable verbose logging.'
)
def run(config, host, port, verbose):
    """
    Starts the Aeiva Agent Server and launches the Unity application.
    """
    # Setup logging
    logger = setup_logging(get_log_dir() / 'maid-chat.log', verbose)

    # Load configuration
    if config is None:
        PACKAGE_ROOT = get_package_root()
        config_path = PACKAGE_ROOT / 'configs' / 'agent_config.yaml'
    else:
        config_path = Path(config)

    logger.info(f"Loading configuration from {config_path}")
    config_dict = from_json_or_yaml(config_path)

    # Validate and start Neo4j
    neo4j_home = os.getenv('NEO4J_HOME')
    if not neo4j_home:
        logger.error("NEO4J_HOME environment variable is not set.")
        click.echo("Error: NEO4J_HOME environment variable is not set.")
        sys.exit(1)

    validate_neo4j_home(logger, neo4j_home)
    neo4j_process = start_neo4j(logger, neo4j_home)

    # Initialize the Agent
    try:
        agent = Agent(config_dict)
        agent.setup()
        logger.info("Agent initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize Agent: {e}")
        click.echo(f"Error: Failed to initialize Agent: {e}")
        stop_neo4j(logger, neo4j_process)
        sys.exit(1)

    # Read MAID_HOME environment variable
    maid_home = os.getenv('MAID_HOME')
    if not maid_home:
        logger.error("MAID_HOME environment variable is not set.")
        click.echo("Error: MAID_HOME environment variable is not set.")
        stop_neo4j(logger, neo4j_process)
        sys.exit(1)

    maid_home_path = Path(maid_home)
    if not maid_home_path.exists():
        logger.error(f"Unity application not found at MAID_HOME: {maid_home}")
        click.echo(f"Error: Unity application not found at MAID_HOME: {maid_home}")
        stop_neo4j(logger, neo4j_process)
        sys.exit(1)

    # Start the Unity application
    unity_process = start_unity_app(str(maid_home_path), logger)
    if unity_process is None:
        stop_neo4j(logger, neo4j_process)
        sys.exit(1)

    # Define the FastAPI app with lifespan
    @asynccontextmanager
    async def lifespan(app: FastAPI):
        app.state.agent = agent
        logger.info("Agent has been initialized and is ready to receive messages.")
        try:
            yield
        finally:
            logger.info("Shutting down the agent server.")
            # If the Agent class has a shutdown method, call it here
            if hasattr(app.state.agent, 'shutdown'):
                await app.state.agent.shutdown()
            stop_neo4j(logger, neo4j_process)
            # Terminate the Unity application
            stop_unity_app(unity_process, logger)
            logger.info("Agent server shut down gracefully.")

    app = FastAPI(lifespan=lifespan)

    # Enable CORS for all origins (for development purposes)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Adjust in production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Define the endpoint
    @app.post("/process_text", response_model=MessageResponse)
    async def process_text(request: MessageRequest):
        if not request.message:
            raise HTTPException(status_code=400, detail="No message provided")

        logger.info(f"Received message: {request.message}")

        # Process the message using the agent
        try:
            response_text = await app.state.agent.process_input(request.message)
            logger.info(f"Agent response: {response_text}")
            return MessageResponse(response=response_text)
        except Exception as e:
            logger.error(f"Error processing input: {e}")
            raise HTTPException(status_code=500, detail="Internal Server Error")

    # Register signal handlers for graceful shutdown using handle_exit
    for sig in [signal.SIGINT, signal.SIGTERM]:
        signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process, unity_process))

    # Run the FastAPI app using Uvicorn
    try:
        logger.info(f"Starting server at http://{host}:{port}")
        uvicorn.run(app, host=host, port=port)
    except Exception as e:
        logger.error(f"Server encountered an error: {e}")
        handle_exit(None, None, logger, neo4j_process, unity_process)  # Ensure cleanup on exception
        sys.exit(1)
    finally:
        logger.info("Server has been stopped.")

start_unity_app(maid_home, logger)

Starts the Unity application.

Parameters:

maid_home (str): Path to the Unity application executable. Required.
logger (logging.Logger): Logger instance. Required.

Returns:

Optional[subprocess.Popen]: The subprocess running the Unity application, or None if it failed to start.

Source code in src/aeiva/command/maid_chat.py
def start_unity_app(maid_home: str, logger: logging.Logger) -> Optional[subprocess.Popen]:
    """
    Starts the Unity application.

    Args:
        maid_home (str): Path to the Unity application executable.
        logger (logging.Logger): Logger instance.

    Returns:
        Optional[subprocess.Popen]: The subprocess running the Unity application, or None if failed.
    """
    try:
        unity_process = subprocess.Popen(
            [maid_home],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            preexec_fn=os.setsid  # Start the process in a new session
        )
        logger.info(f"Unity application started from {maid_home}.")
        click.echo(f"Unity application started from {maid_home}.")
        return unity_process
    except FileNotFoundError:
        logger.error(f"Unity application not found at {maid_home}.")
        click.echo(f"Error: Unity application not found at {maid_home}.")
        return None
    except Exception as e:
        logger.error(f"Failed to start Unity application: {e}")
        click.echo(f"Error: Failed to start Unity application: {e}.")
        return None

stop_unity_app(unity_process, logger)

Stops the Unity application gracefully.

Parameters:

unity_process (subprocess.Popen): The subprocess running the Unity application. Required.
logger (logging.Logger): Logger instance. Required.
Source code in src/aeiva/command/maid_chat.py
def stop_unity_app(unity_process: subprocess.Popen, logger: logging.Logger):
    """
    Stops the Unity application gracefully.

    Args:
        unity_process (subprocess.Popen): The subprocess running the Unity application.
        logger (logging.Logger): Logger instance.
    """
    try:
        os.killpg(os.getpgid(unity_process.pid), signal.SIGTERM)
        unity_process.wait(timeout=10)
        logger.info("Unity application terminated gracefully.")
        click.echo("Unity application terminated gracefully.")
    except Exception as e:
        logger.error(f"Error terminating Unity application: {e}")
        click.echo(f"Error: Failed to terminate Unity application: {e}.")

common

decorators

import_submodules(package, recursive=True)

Import all submodules of a module, recursively, including subpackages

Source code in src/aeiva/common/decorators.py
def import_submodules(package, recursive=True):
    """ Import all submodules of a module, recursively, including subpackages """

    if isinstance(package, str):
        package = importlib.import_module(package)

    results = {}

    for loader, name, is_pkg in pkgutil.walk_packages(package.__path__):
        full_name = package.__name__ + "." + name
        results[full_name] = importlib.import_module(full_name)
        if recursive and is_pkg:
            results.update(import_submodules(full_name))

    return results
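
A short usage sketch (the package name is illustrative; any importable package works):

from aeiva.common.decorators import import_submodules

# Eagerly import every submodule and get back a {full_name: module} mapping.
modules = import_submodules("aeiva.common")
print(sorted(modules.keys()))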

id_generator

IDGenerator

A simple class to generate unique IDs for distinct names.

Attributes:

name_to_id (dict): A dictionary to map names to IDs.
next_id (int): The next ID to be assigned.

Source code in src/aeiva/common/id_generator.py
class IDGenerator:
    """
    A simple class to generate unique IDs for distinct names.

    Attributes:
        name_to_id (dict): A dictionary to map names to IDs.
        next_id (int): The next ID to be assigned.
    """

    def __init__(self):
        """
        Constructs all the necessary attributes for the IDGenerator object.

        Attributes:
            name_to_id (dict): Initializes an empty dictionary to map names to IDs.
            next_id (int): Initializes the next ID to be assigned as 0.
        """
        self.name_to_id = {}
        self.next_id = 0

    def get_id(self, name: str) -> int:
        """
        Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.

        Parameters:
            name (str): The name for which the ID is required.

        Returns:
            int: The ID associated with the 'name'.
        """
        if name not in self.name_to_id:
            self.name_to_id[name] = self.next_id
            self.next_id += 1
        return self.name_to_id[name]
__init__()

Constructs all the necessary attributes for the IDGenerator object.

Attributes:

name_to_id (dict): Initialized as an empty dictionary mapping names to IDs.
next_id (int): Initialized to 0, the first ID to be assigned.

Source code in src/aeiva/common/id_generator.py
def __init__(self):
    """
    Constructs all the necessary attributes for the IDGenerator object.

    Attributes:
        name_to_id (dict): Initializes an empty dictionary to map names to IDs.
        next_id (int): Initializes the next ID to be assigned as 0.
    """
    self.name_to_id = {}
    self.next_id = 0
get_id(name)

Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.

Parameters:

name (str): The name for which the ID is required. Required.

Returns:

int: The ID associated with the 'name'.

Source code in src/aeiva/common/id_generator.py
def get_id(self, name: str) -> int:
    """
    Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.

    Parameters:
        name (str): The name for which the ID is required.

    Returns:
        int: The ID associated with the 'name'.
    """
    if name not in self.name_to_id:
        self.name_to_id[name] = self.next_id
        self.next_id += 1
    return self.name_to_id[name]
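
A short usage sketch:

from aeiva.common.id_generator import IDGenerator

gen = IDGenerator()
print(gen.get_id("alice"))  # 0 (first distinct name)
print(gen.get_id("bob"))    # 1
print(gen.get_id("alice"))  # 0 (same name, same ID)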

pipeline

Pipeline

This class is used to turn a list of functions into a pipeline.

Source code in src/aeiva/common/pipeline.py
class Pipeline:
    r"""This class is used to rurn a list of functions into a pipeline."""
    def __init__(self, functions):
        self.functions = functions

    def run(self, *args, **kwargs):
        result = self.functions[0](*args, **kwargs)
        for f in self.functions[1:]:
            if isinstance(result, tuple):
                result = f(*result)
            else:
                result = f(result)
        return result

    def __call__(self, *args, **kwargs):
        return self.run(*args, **kwargs)
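
A usage sketch with two hypothetical functions. Each intermediate result here is a single value, so it is passed positionally to the next function; a tuple result would be unpacked into multiple arguments.

from aeiva.common.pipeline import Pipeline

def tokenize(text):
    return text.split()

def count(tokens):
    return len(tokens)

pipe = Pipeline([tokenize, count])
print(pipe("an evolving intelligent virtual assistant"))  # 5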

types

DataBatch

Bases: TypedDict

DataBatch is a batch of data items created by a dataloader.

Source code in src/aeiva/common/types.py
class DataBatch(TypedDict):
    r"""DataBatch is a batch of data items created by a dataloader.
    """
    videos: Optional[torch.Tensor]  # videos representation
    audios: Optional[torch.Tensor]  # audios representation
    images: Optional[torch.Tensor]  # images representation
    input_ids: Optional[torch.Tensor]  # text token ids
    attention_mask: Optional[torch.Tensor]  # attention mask
    image_starts: Optional[torch.Tensor]  # image start token
    image_ends: Optional[torch.Tensor]  # image end token
    audio_starts: Optional[torch.Tensor]  # audio start token
    audio_ends: Optional[torch.Tensor]  # audio end token
    video_starts: Optional[torch.Tensor]  # video start token
    video_ends: Optional[torch.Tensor]  # video end token
    labels: Optional[torch.Tensor]  # labels

DataItem

Bases: TypedDict

DataItem is a dictionary that contains all the information for a single data item.

Source code in src/aeiva/common/types.py
class DataItem(TypedDict):
    r"""DataItem is a dictionary that contains all the information for a single data item.
    """
    instruction: str  # instruction text
    input: Optional[str]  # input text
    output: Optional[str]  # output text
    text: Optional[str]  # text field. How it is formed depends on the task.

    image: Optional[str]  # image name or path
    transformed_image: Optional[torch.Tensor]  # transformed image tensor

    audio: Optional[str]  # audio name or path
    audio_mels: Optional[torch.Tensor]  # audio melspectrogram tensor

    video: Optional[str]  # video name or path
    sampled_video_frame_indices: Optional[list[int]]  # sampled video frame indices
    video_frames: Optional[torch.Tensor]  # video frames tensor

DataSet

Bases: TypedDict

DataSet is a dictionary that contains data items and meta information.

Source code in src/aeiva/common/types.py
class DataSet(TypedDict):
    r"""DataSet is a dictionary that contains data items and meta information.
    """
    data: list[DataItem]
    metadata: dict[str, Any]
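
A minimal construction sketch (TypedDicts are plain dictionaries at runtime; the values below are placeholders):

from aeiva.common.types import DataSet

dataset: DataSet = {
    "data": [],                      # would hold DataItem dictionaries
    "metadata": {"source": "demo"},  # free-form meta information
}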

ModelInput

Bases: TypedDict

ModelInput is a dictionary that contains all the information for a model input. We use it to construct LEGO style models.

Source code in src/aeiva/common/types.py
class ModelInput(TypedDict):
    r"""ModelInput is a dictionary that contains all the information for a model input.
    We use it to construct LEGO style models.
    """
    pass

ModelOutput

Bases: TypedDict

ModelOutput is a dictionary that contains all the information for a model output. We use it to construct LEGO style models.

Source code in src/aeiva/common/types.py
class ModelOutput(TypedDict):
    r"""ModelOutput is a dictionary that contains all the information for a model output.
    We use it to construct LEGO style models.
    """
    pass

TaskContext

Bases: TypedDict

TaskContext is a dictionary that contains all the information for a task.

Source code in src/aeiva/common/types.py
class TaskContext(TypedDict):
    r"""TaskContext is a dictionary that contains all the information for a task.
    """
    config_path: Optional[str]
    config: Optional[OmniConfig]
    dataloader: Optional[torch.utils.data.DataLoader]
    tokenizer: Optional[Any]
    model: Optional[Any]
    logger: Optional[Any]
    trainer: Optional[Any]
    current_model_input: Optional[DataItem]
    current_model_output: Optional[Any]

config

DataConfig dataclass

Bases: BaseConfig

This class contains the data configuration.

Source code in src/aeiva/config/general_configs.py
@dataclass
class DataConfig(BaseConfig):
    """This class contains the data configuration."""
    dataset_path: Optional[str] = field(
        default=None, metadata={"help": "The path of the dataset to use."}
    )
    dataset_name: Optional[str] = field(
        default="customized", metadata={"help": "Should be \"customized\""}
    )
    is_custom_dataset: Optional[bool] = field(
        default=False, metadata={"help": "whether to use custom data"}
    )
    customized_cache_dir: Optional[str] = field(
        default=".cache/llm-ft/datasets",
        metadata={"help": "Where do you want to store the customized dataset caches"},
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=1e10,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: Optional[bool] = field(default=False, metadata={"help": "Enable streaming mode"})
    block_size: Optional[int] = field(
        default=512,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    group_texts_batch_size: Optional[int] = field(
        default=1000,
        metadata={
            "help": (
                "Number of samples that will be grouped together to go though"
                " `group_texts` operation. See `--disable_group_texts` for"
                " detailed explanation of this operation."
            )
        }
    )
    disable_group_texts: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether we group original samples together to generate sample"
                " sequences of length `block_size`. By default, we group every"
                " 1000 tokenized sequences together, divide them into "
                " [{total_num_tokens} / {block_size}] sequences, each with"
                " `block_size` tokens (the remaining tokens are ommited."
                " If this flag is set to True, we only group 1 tokenized"
                " sequence, i.e. cutting long sequence into chunks."
            )
        },
    )
    keep_linebreaks: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "Evaluation File Path"},
    )

    def __post_init__(self):
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
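
A construction sketch (the file name is illustrative; __post_init__ validates the extension of any train/validation file):

from aeiva.config.general_configs import DataConfig

# train_file must be a csv, json, or txt file; block_size overrides the default of 512.
data_config = DataConfig(train_file="train.json", block_size=1024)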

ExplicitEnum

Bases: str, Enum

Enum with more explicit error message for missing values.

Source code in src/aeiva/config/general_configs.py
class ExplicitEnum(str, Enum):
    """
    Enum with more explicit error message for missing values.
    """
    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )
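
A sketch of the resulting error message (Color is a hypothetical subclass):

from aeiva.config.general_configs import ExplicitEnum

class Color(ExplicitEnum):
    RED = "red"
    BLUE = "blue"

Color("green")  # ValueError: green is not a valid Color, please select one of ['red', 'blue']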

ModelConfig dataclass

Bases: BaseConfig

Model configuration class.

Source code in src/aeiva/config/general_configs.py
@dataclass
class ModelConfig(BaseConfig):
    """Model configuration class."""
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    lora_model_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The incremental model diff introduced by LoRA finetuning."
                " Along with the original non-finetuned model forms the whole"
                " finetuned model."
            )
        }
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    arch_type: Optional[str] = field(
        default="decoder_only",
        metadata={
            "help": (
                "Model architecture type, e.g. \"decoder_only\","
                " \"encoder_decoder\""
            ),
            "choices": ["decoder_only", "encoder_decoder", "text_regression", "vision_encoder_decoder"],
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    use_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use LoRA."},
    )
    lora_r: Optional[int] = field(
        default=8,
        metadata={"help": "The rank of the LoRA parameters. The smaller lora_r is, the fewer parameters LoRA has."},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": "Merging ratio between the fine-tuned model and the original. This is controlled by a parameter called alpha in the paper."},
    )
    lora_target_modules: Optional[list[str]] = field(
        default=None,
        metadata={"help": "The names of the modules to apply LoRA to."},
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
        metadata={"help": "The dropout rate in lora.linear."},
    )
    save_aggregated_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to save the aggregated LoRA weights."},
    )
    use_ram_optimized_load: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use disk mapping when memory is not enough."}
    )
    use_flash_attention: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether to use flash attention layers to reduce GPU memory,"
                " at the cost of longer runtime."
            )
        }
    )
    use_int8: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to load the model with int8 quantization for inference."}
    )
    custom_model: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether the model is a custom model rather than one from Hugging Face."}
    )
    # below is added for macaw model
    n_frames: Optional[int] = field(
        default=6,
        metadata={
            "help": "The number of frames for encoding a video."
        },
    )
    attention_heads: Optional[int] = field(
        default=220,
        metadata={
            "help": "The number of attention heads used in multi-head-attention."
        },
    )
    image_conv_kernel: Optional[int] = field(
        default=48,
        metadata={
            "help": "The size of the convolutional kernel for the image stream."
        },
    )
    image_conv_stride: Optional[int] = field(
        default=36,
        metadata={
            "help": "The stride of the convolutional kernel for the image stream."
        },
    )
    video_conv_kernel: Optional[int] = field(
        default=36,
        metadata={
            "help": "The size of the convolutional kernel for the video stream."
        },
    )
    video_conv_stride: Optional[int] = field(
        default=30,
        metadata={
            "help": "The stride of the convolutional kernel for the video stream."
        },
    )
    audio_conv_kernel: Optional[int] = field(
        default=240,
        metadata={
            "help": "The size of the convolutional kernel for the audio stream."
        },
    )
    audio_conv_stride: Optional[int] = field(
        default=220,
        metadata={
            "help": "The stride of the convolutional kernel for the audio stream."
        },
    )
    freeze_multi_modal_encoder: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to freeze the parameters of multi-modal encoders during training."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )
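
A minimal usage sketch of this validation (the checkpoint name is illustrative, not prescribed by Aeiva):

cfg = ModelConfig(model_name_or_path="gpt2", use_lora=True, lora_r=8)
print(cfg.arch_type)  # "decoder_only" (the default)

ModelConfig(model_name_or_path="gpt2", config_overrides="n_embd=10")
# ValueError: --config_overrides can't be used in combination with --config_name or --model_name_or_path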

OptimizerNames

Bases: ExplicitEnum

Stores the acceptable string identifiers for optimizers.

Source code in src/aeiva/config/general_configs.py, lines 60–80
class OptimizerNames(ExplicitEnum):
    """
    Stores the acceptable string identifiers for optimizers.
    """
    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_FUSED = "adamw_torch_fused"
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
    LION_8BIT = "lion_8bit"
    LION = "lion_32bit"
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
    PAGED_LION = "paged_lion_32bit"
    PAGED_LION_8BIT = "paged_lion_8bit"
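
Because OptimizerNames extends ExplicitEnum, an invalid identifier fails with the explicit error message (a minimal sketch):

OptimizerNames("adamw_torch")  # returns <OptimizerNames.ADAMW_TORCH: 'adamw_torch'>
OptimizerNames("adam")         # ValueError: adam is not a valid OptimizerNames, please select one of ['adamw_hf', 'adamw_torch', ...]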

base_config

This module contains the base config classes.

We can define separate config classes for different modules, e.g., data, model, trainer, llm, etc. They will be automatically registered in the BaseConfig class.

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

BaseConfig dataclass

Base class for all configuration classes.

Source code in src/aeiva/config/base_config.py, lines 21–107
@dataclass
class BaseConfig:
    """
    Base class for all configuration classes.
    """
    subclasses = {}  # Dictionary to store subclasses

    def __init_subclass__(cls, **kwargs):
        """
        This method is called when a subclass is created.
        """
        super().__init_subclass__(**kwargs)
        BaseConfig.subclasses[cls.__name__] = cls

    def __post_init__(self):
        """
        Empty post-init to allow subclasses to call super().__post_init__().
        """
        pass

    @classmethod
    def from_dict(cls, data: dict):
        """
        Create a new instance of the class from a dictionary.
        """
        try:
            return cls(**data)
        except TypeError as e:
            invalid_keys = [key.strip("'") for key in re.findall(r"'(\w+)'", str(e))]
            raise ValueError(f"Invalid config keys provided: {invalid_keys}. Details: {e}")

    def to_dict(self):
        """
        Convert the instance to a dictionary.
        """
        return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}

    @classmethod
    def from_json(cls, json_path: str):
        """
        Create a new instance of the class from a JSON file.
        """
        with open(json_path, "r") as json_file:
            data = json.load(json_file)
        return cls.from_dict(data)

    def to_json(self, filepath: str):
        """
        Convert the instance to a JSON file.
        """
        with open(filepath, 'w') as json_file:
            json.dump(self.to_dict(), json_file, indent=4)

    @classmethod
    def from_yaml(cls, yaml_path: str):
        """
        Create a new instance of the class from a YAML file.
        """
        with open(yaml_path, "r") as yaml_file:
            data = yaml.safe_load(yaml_file)
        return cls.from_dict(data)

    def to_yaml(self, filepath: str):
        """
        Convert the instance to a YAML file.
        """
        with open(filepath, 'w') as yaml_file:
            yaml.dump(self.to_dict(), yaml_file)

    @classmethod
    def from_json_or_yaml(cls, file_path: str):
        """
        Create a new instance of the class from a JSON or YAML file.
        """
        _, file_extension = os.path.splitext(file_path)
        if file_extension == ".json":
            return cls.from_json(file_path)
        elif file_extension == ".yaml" or file_extension == ".yml":
            return cls.from_yaml(file_path)
        else:
            raise ValueError(f"Unsupported file extension: {file_extension}. Please use .json or .yaml")

    def __str__(self):
        """
        Return a string representation of the instance.
        """
        return pprint.pformat(self.to_dict(), indent=4)
__init_subclass__(**kwargs)

This method is called when a subclass is created.

Source code in src/aeiva/config/base_config.py, lines 28–33
def __init_subclass__(cls, **kwargs):
    """
    This method is called when a subclass is created.
    """
    super().__init_subclass__(**kwargs)
    BaseConfig.subclasses[cls.__name__] = cls
__post_init__()

Empty post-init to allow subclasses to call super().__post_init__().

Source code in src/aeiva/config/base_config.py, lines 35–39
def __post_init__(self):
    """
    Empty post-init to allow subclasses to call super().__post_init__().
    """
    pass
__str__()

Return a string representation of the instance.

Source code in src/aeiva/config/base_config.py, lines 103–107
def __str__(self):
    """
    Return a string representation of the instance.
    """
    return pprint.pformat(self.to_dict(), indent=4)
from_dict(data) classmethod

Create a new instance of the class from a dictionary.

Source code in src/aeiva/config/base_config.py, lines 41–50
@classmethod
def from_dict(cls, data: dict):
    """
    Create a new instance of the class from a dictionary.
    """
    try:
        return cls(**data)
    except TypeError as e:
        invalid_keys = [key.strip("'") for key in re.findall(r"'(\w+)'", str(e))]
        raise ValueError(f"Invalid config keys provided: {invalid_keys}. Details: {e}")
from_json(json_path) classmethod

Create a new instance of the class from a JSON file.

Source code in src/aeiva/config/base_config.py, lines 58–65
@classmethod
def from_json(cls, json_path: str):
    """
    Create a new instance of the class from a JSON file.
    """
    with open(json_path, "r") as json_file:
        data = json.load(json_file)
    return cls.from_dict(data)
from_json_or_yaml(file_path) classmethod

Create a new instance of the class from a JSON or YAML file.

Source code in src/aeiva/config/base_config.py, lines 90–101
@classmethod
def from_json_or_yaml(cls, file_path: str):
    """
    Create a new instance of the class from a JSON or YAML file.
    """
    _, file_extension = os.path.splitext(file_path)
    if file_extension == ".json":
        return cls.from_json(file_path)
    elif file_extension == ".yaml" or file_extension == ".yml":
        return cls.from_yaml(file_path)
    else:
        raise ValueError(f"Unsupported file extension: {file_extension}. Please use .json or .yaml")
from_yaml(yaml_path) classmethod

Create a new instance of the class from a YAML file.

Source code in src/aeiva/config/base_config.py, lines 74–81
@classmethod
def from_yaml(cls, yaml_path: str):
    """
    Create a new instance of the class from a YAML file.
    """
    with open(yaml_path, "r") as yaml_file:
        data = yaml.safe_load(yaml_file)
    return cls.from_dict(data)
to_dict()

Convert the instance to a dictionary.

Source code in src/aeiva/config/base_config.py, lines 52–56
def to_dict(self):
    """
    Convert the instance to a dictionary.
    """
    return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}
to_json(filepath)

Convert the instance to a JSON file.

Source code in src/aeiva/config/base_config.py, lines 67–72
def to_json(self, filepath: str):
    """
    Convert the instance to a JSON file.
    """
    with open(filepath, 'w') as json_file:
        json.dump(self.to_dict(), json_file, indent=4)
to_yaml(filepath)

Convert the instance to a YAML file.

Source code in src/aeiva/config/base_config.py, lines 83–88
def to_yaml(self, filepath: str):
    """
    Convert the instance to a YAML file.
    """
    with open(filepath, 'w') as yaml_file:
        yaml.dump(self.to_dict(), yaml_file)
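
Putting these pieces together, a minimal sketch of subclass auto-registration and a YAML round-trip (MyConfig and the file path are hypothetical; BaseConfig is assumed imported from aeiva.config.base_config):

from dataclasses import dataclass

@dataclass
class MyConfig(BaseConfig):
    learning_rate: float = 1e-3

print("MyConfig" in BaseConfig.subclasses)  # True: registered by __init_subclass__

cfg = MyConfig.from_dict({"learning_rate": 0.01})
cfg.to_yaml("/tmp/my_config.yaml")                            # serialize to YAML
restored = MyConfig.from_json_or_yaml("/tmp/my_config.yaml")  # dispatches on the file extension
print(restored)                                               # pretty-printed dict via __str__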

custom_configs

macaw_config

This module contains the config for macaw model.

We can define separate config classes for different specific models/datasets/tasks.

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

MacawConfig dataclass

Bases: BaseConfig

Define user-customized config here.

Source code in src/aeiva/config/custom_configs/macaw_config.py, lines 17–97
@dataclass
class MacawConfig(BaseConfig):
    """
    Define user-customized config here.
    """
    image_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The directory of image data"}
    )
    video_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The directory of video data"}
    )
    frame_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The directory to save video frames"}
    )
    audio_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The directory to save video audios"}
    )
    num_frames_to_sample: Optional[int] = field(
        default=120,
        metadata={"help": "The number of frames to sample from a video"}
    )
    num_frames_to_load: Optional[int] = field(
        default=6,
        metadata={"help": "The number of frames to load as a part of model inputs"}
    )
    num_samples_per_dataset: Optional[int] = field(
        default=100,
        metadata={"help": "The number of samples to load from each dataset"}
    )
    num_samples_per_merged_dataset: Optional[int] = field(
        default=20,
        metadata={"help": "The number of samples to save after merging datasets"}
    )
    batch_size: Optional[int] = field(
        default=1,
        metadata={"help": "The batch size of model inputs"}
    )
    max_seq_len_for_preprocess: Optional[int] = field(
        default=256,
        metadata={"help": "The maximum sequence length for preprocess"}
    )
    run_time_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The directory to save running time data, such as video frames, audios, and so on."}
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The name or path of tokenizer"}
    )
    clip_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The name or path of clip model"}
    )
    whisper_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The name or path of whisper model"}
    )
    llama7b_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The name or path of llama7b model"}
    )
    macaw_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The name or path of macaw model"}
    )
    mode: Optional[str] = field(
        default="train",
        metadata={"help": "The mode of train, eval, or inference"}
    )
    model_name: Optional[str] = field(
        default="macaw",
        metadata={"help": "The name of model"}
    )
    resource_ready: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether the pre-requisite resource is ready, e.g., download pretrained models and datasets"}
    )
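
Because MacawConfig inherits BaseConfig, it can be built from a plain dictionary, and unknown keys fail with a clear message (a sketch with illustrative values):

cfg = MacawConfig.from_dict({"mode": "inference", "num_frames_to_load": 6})
MacawConfig.from_dict({"nonexistent_key": 1})
# ValueError: Invalid config keys provided: ['nonexistent_key']. Details: ...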

general_configs

This module contains some general config classes that can be used in deep learning projects.

E.g., data config, model config, trainer config, etc.

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

DataConfig dataclass

Bases: BaseConfig

This class contains the data configuration.

Source code in src/aeiva/config/general_configs.py, lines 83–195
@dataclass
class DataConfig(BaseConfig):
    """This class contains the data configuration."""
    dataset_path: Optional[str] = field(
        default=None, metadata={"help": "The path of the dataset to use."}
    )
    dataset_name: Optional[str] = field(
        default="customized", metadata={"help": "Should be \"customized\""}
    )
    is_custom_dataset: Optional[bool] = field(
        default=False, metadata={"help": "whether to use custom data"}
    )
    customized_cache_dir: Optional[str] = field(
        default=".cache/llm-ft/datasets",
        metadata={"help": "Where do you want to store the customized dataset caches"},
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=int(1e10),
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: Optional[bool] = field(default=False, metadata={"help": "Enable streaming mode"})
    block_size: Optional[int] = field(
        default=512,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    group_texts_batch_size: Optional[int] = field(
        default=1000,
        metadata={
            "help": (
                "Number of samples that will be grouped together to go through"
                " the `group_texts` operation. See `--disable_group_texts` for"
                " a detailed explanation of this operation."
            )
        }
    )
    disable_group_texts: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether we group original samples together to generate sample"
                " sequences of length `block_size`. By default, we group every"
                " 1000 tokenized sequences together, divide them into"
                " [{total_num_tokens} / {block_size}] sequences, each with"
                " `block_size` tokens (the remaining tokens are omitted)."
                " If this flag is set to True, we only group 1 tokenized"
                " sequence, i.e. cutting long sequences into chunks."
            )
        },
    )
    keep_linebreaks: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "Evaluation File Path"},
    )

    def __post_init__(self):
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
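
The checks above can be exercised directly (a minimal sketch; the paths are illustrative):

DataConfig(train_file="data/train.json")     # ok: json is an accepted extension
DataConfig(train_file="data/train.parquet")  # AssertionError: `train_file` should be a csv, a json or a txt file.
DataConfig(dataset_name=None)                # ValueError: Need either a dataset name or a training/validation file.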

omni_config

This module contains the OmniConfig classes.

We can define separate config classes for different modules, e.g., data, model, trainer, etc. The OmniConfig class is the combination of all config classes. It can also accept command line arguments to update the config values.

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

OmniConfig dataclass

Bases: BaseConfig

Source code in src/aeiva/config/omni_config.py, lines 25–132
@dataclass
class OmniConfig(BaseConfig):
    @staticmethod
    def create_omni_config():
        """
        Initializes OmniConfig by aggregating all configuration classes.
        """
        # Aggregating default values from all config classes
        defaults = {}
        for config_class_name, config_class in BaseConfig.subclasses.items():
            if config_class_name == "OmniConfig":
                continue
            for field_name, field_obj in config_class.__dataclass_fields__.items():
                if field_name in defaults:
                    raise ValueError(f"Overlapping config argument: '{field_name}' found in {config_class.__name__}")
                default_value = getattr(config_class(), field_name, None)
                defaults[field_name] = default_value

        def __init__(self, **kwargs):
            for key, default_value in defaults.items():
                setattr(self, key, kwargs.get(key, default_value))

        OmniConfig.__init__ = __init__
        return OmniConfig

    def update_from_args(self, namespace_args: argparse.Namespace):
        """
        Updates the configuration based on parsed command-line arguments.
        """
        for key, value in vars(namespace_args).items():
            if hasattr(self, key) and value is not None:
                setattr(self, key, value)

    def get_argparse_parser(self):
        """
        Creates an argument parser that can handle complex types.
        """
        parser = argparse.ArgumentParser()
        for config_class_name, config_class in BaseConfig.subclasses.items():
            if config_class_name == "OmniConfig":
                continue
            for field_name, field_obj in config_class.__dataclass_fields__.items():
                field_type = field_obj.type

                # Handle Optional types
                if get_origin(field_type) is Union and type(None) in get_args(field_type):
                    field_type = next(arg for arg in get_args(field_type) if arg is not type(None))

                arg_name = '--' + field_name
                help_msg = field_obj.metadata.get("help", f"{field_name} ({field_type})")

                origin = get_origin(field_type)
                args = get_args(field_type)

                # Handle Enums
                if isinstance(field_type, type) and issubclass(field_type, enum.Enum):
                    choices = [item.value for item in field_type]
                    parser.add_argument(arg_name, type=str, choices=choices, help=help_msg)
                    continue

                # Handle list types
                if origin is list:
                    item_type = args[0]
                    if item_type is str:
                        parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)
                    elif item_type is int:
                        parser.add_argument(arg_name, nargs='+', type=int, help=help_msg)
                    else:
                        # Default to strings if item type is not specifically handled
                        parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)
                    continue

                # Handle tuple types
                if origin is tuple:
                    # Accept comma-separated values and convert to tuple
                    def tuple_type(s):
                        try:
                            return tuple(map(int, s.split(',')))
                        except ValueError:
                            raise argparse.ArgumentTypeError("Tuples must be comma-separated integers.")

                    parser.add_argument(arg_name, type=tuple_type, help=help_msg)
                    continue

                # Handle dict types
                if origin is dict:
                    # Expect JSON string
                    def dict_type(s):
                        try:
                            return json.loads(s)
                        except json.JSONDecodeError:
                            raise argparse.ArgumentTypeError("Dictionaries must be valid JSON strings.")

                    parser.add_argument(arg_name, type=dict_type, help=help_msg)
                    continue

                # Handle basic types
                if field_type is int:
                    parser.add_argument(arg_name, type=int, help=help_msg)
                elif field_type is float:
                    parser.add_argument(arg_name, type=float, help=help_msg)
                elif field_type is str:
                    parser.add_argument(arg_name, type=str, help=help_msg)
                elif field_type is bool:
                    parser.add_argument(arg_name, action='store_true', help=help_msg)
                else:
                    print(f"Warning: unsupported type {field_type} for field '{field_name}'")
        return parser
create_omni_config() staticmethod

Initializes OmniConfig by aggregating all configuration classes.

Source code in src/aeiva/config/omni_config.py, lines 27–48
@staticmethod
def create_omni_config():
    """
    Initializes OmniConfig by aggregating all configuration classes.
    """
    # Aggregating default values from all config classes
    defaults = {}
    for config_class_name, config_class in BaseConfig.subclasses.items():
        if config_class_name == "OmniConfig":
            continue
        for field_name, field_obj in config_class.__dataclass_fields__.items():
            if field_name in defaults:
                raise ValueError(f"Overlapping config argument: '{field_name}' found in {config_class.__name__}")
            default_value = getattr(config_class(), field_name, None)
            defaults[field_name] = default_value

    def __init__(self, **kwargs):
        for key, default_value in defaults.items():
            setattr(self, key, kwargs.get(key, default_value))

    OmniConfig.__init__ = __init__
    return OmniConfig
get_argparse_parser()

Creates an argument parser that can handle complex types.

Source code in src/aeiva/config/omni_config.py, lines 58–132
def get_argparse_parser(self):
    """
    Creates an argument parser that can handle complex types.
    """
    parser = argparse.ArgumentParser()
    for config_class_name, config_class in BaseConfig.subclasses.items():
        if config_class_name == "OmniConfig":
            continue
        for field_name, field_obj in config_class.__dataclass_fields__.items():
            field_type = field_obj.type

            # Handle Optional types
            if get_origin(field_type) is Union and type(None) in get_args(field_type):
                field_type = next(arg for arg in get_args(field_type) if arg is not type(None))

            arg_name = '--' + field_name
            help_msg = field_obj.metadata.get("help", f"{field_name} ({field_type})")

            origin = get_origin(field_type)
            args = get_args(field_type)

            # Handle Enums
            if isinstance(field_type, type) and issubclass(field_type, enum.Enum):
                choices = [item.value for item in field_type]
                parser.add_argument(arg_name, type=str, choices=choices, help=help_msg)
                continue

            # Handle list types
            if origin is list:
                item_type = args[0]
                if item_type is str:
                    parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)
                elif item_type is int:
                    parser.add_argument(arg_name, nargs='+', type=int, help=help_msg)
                else:
                    # Default to strings if item type is not specifically handled
                    parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)
                continue

            # Handle tuple types
            if origin is tuple:
                # Accept comma-separated values and convert to tuple
                def tuple_type(s):
                    try:
                        return tuple(map(int, s.split(',')))
                    except ValueError:
                        raise argparse.ArgumentTypeError("Tuples must be comma-separated integers.")

                parser.add_argument(arg_name, type=tuple_type, help=help_msg)
                continue

            # Handle dict types
            if origin is dict:
                # Expect JSON string
                def dict_type(s):
                    try:
                        return json.loads(s)
                    except json.JSONDecodeError:
                        raise argparse.ArgumentTypeError("Dictionaries must be valid JSON strings.")

                parser.add_argument(arg_name, type=dict_type, help=help_msg)
                continue

            # Handle basic types
            if field_type is int:
                parser.add_argument(arg_name, type=int, help=help_msg)
            elif field_type is float:
                parser.add_argument(arg_name, type=float, help=help_msg)
            elif field_type is str:
                parser.add_argument(arg_name, type=str, help=help_msg)
            elif field_type is bool:
                parser.add_argument(arg_name, action='store_true', help=help_msg)
            else:
                print(f"Warning: unsupported type {field_type} for field '{field_name}'")
    return parser
update_from_args(namespace_args)

Updates the configuration based on parsed command-line arguments.

Source code in src/aeiva/config/omni_config.py, lines 50–56
def update_from_args(self, namespace_args: argparse.Namespace):
    """
    Updates the configuration based on parsed command-line arguments.
    """
    for key, value in vars(namespace_args).items():
        if hasattr(self, key) and value is not None:
            setattr(self, key, value)
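
A typical flow combines the three methods above (a sketch; it assumes every registered config class can be instantiated with its defaults, and the command-line flags are whatever fields those classes define):

OmniConfigClass = OmniConfig.create_omni_config()  # aggregate defaults from all BaseConfig subclasses
config = OmniConfigClass()                         # instance populated with the aggregated defaults

parser = config.get_argparse_parser()
args, _ = parser.parse_known_args()                # e.g. --model_name_or_path gpt2 --use_lora
config.update_from_args(args)                      # only overrides fields that were actually provided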

data

processor

This module contains the data processor.

@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

process_dataset(formatted_dataset, pipeline, output_dir, dataset_name='')

Process a dataset with a pipeline of functions.

Parameters:

- formatted_dataset (DataSet): the dataset to be processed. Required.
- pipeline (list[Callable]): a list of functions to be applied to the dataset. Required.
- output_dir (Optional[str]): the output directory to save the processed dataset. Required.
- dataset_name (Optional[str]): the name of the dataset. Defaults to ''.

Returns:

- DataSet: the processed dataset.

Source code in src/aeiva/data/processor.py, lines 21–46
def process_dataset(formatted_dataset: DataSet,
                    pipeline: list[Callable],
                    output_dir: Optional[str],
                    dataset_name: Optional[str] = "") -> DataSet:
    """
    Process a dataset with a pipeline of functions.

    Args:
        formatted_dataset (DataSet): the dataset to be processed.
        pipeline (list[Callable]): a list of functions to be applied to the dataset.
        output_dir (Optional[str]): the output directory to save the processed dataset.
        dataset_name (Optional[str], optional): the name of the dataset. Defaults to "".

    Returns:
        DataSet: the processed dataset.
    """
    processed_data = []
    pipeline = Pipeline(pipeline)
    for item in formatted_dataset["data"]:
        processed_data.append(pipeline(item.copy()))

    output = {"data": processed_data, "metadata": formatted_dataset["metadata"]}
    if output_dir is not None:
        ensure_dir(output_dir)
        dump_json(output, f"{output_dir}/{dataset_name}_dataset.processed.json")
    return output
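
A usage sketch (the step functions and data are hypothetical; it assumes aeiva's Pipeline applies the given functions in order, each taking and returning a data-item dict):

def lowercase_text(item):
    item["text"] = item["text"].lower()
    return item

def strip_text(item):
    item["text"] = item["text"].strip()
    return item

dataset = {"data": [{"text": "  Hello World  "}], "metadata": {"source": "demo"}}
processed = process_dataset(dataset, [lowercase_text, strip_text], output_dir=None, dataset_name="demo")
print(processed["data"][0]["text"])  # "hello world"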

demo

chat_gradio

bot(user_input, history) async

Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.

Source code in src/aeiva/demo/chat_gradio.py, lines 145–226
async def bot(user_input, history):
    """
    Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.
    """
    if agent is None:
        logger.error("Agent is not initialized.")
        history.append({"role": "assistant", "content": "Agent is not initialized."})
        yield history, ''
        return

    try:
        # Append user's message to history
        history.append({"role": "user", "content": user_input})
        # Append an empty assistant response
        history.append({"role": "assistant", "content": ""})
        yield history, ''  # Display the user's message
        logger.info(f"User input appended to history: {user_input}")

        stream = config_dict["llm_gateway_config"]["llm_stream"]
        use_async = config_dict["llm_gateway_config"]["llm_use_async"]

        # Emit the 'perception.gradio' event with stream=True
        emit_future = asyncio.run_coroutine_threadsafe(
            agent.event_bus.emit('perception.gradio', payload=user_input),  # TODO: maybe simplify payload, Agent can directly read stream and use_async from config.
            agent.event_bus.loop
        )
        emit_future.result()  # Ensure the event is emitted
        logger.info(f"Emitted 'perception.gradio' event with payload: {user_input} | Stream: {stream}")

        assistant_message = ''
        if stream:
            while True:
                try:
                    # Non-blocking response retrieval from the thread-safe queue with timeout
                    response = await asyncio.wait_for(
                        asyncio.to_thread(response_queue.get, True, 30),
                        timeout=30
                    )
                    logger.info(f"Retrieved response from queue: {response}")
                    if response == "<END_OF_RESPONSE>":
                        logger.info("Received end of response signal.")
                        break
                    assistant_message += response
                    # Create a new history list to ensure Gradio detects the update
                    new_history = history.copy()
                    new_history[-1]["content"] = assistant_message
                    logger.info(f"Yielding updated history: {new_history}")
                    yield new_history, ''
                except asyncio.TimeoutError:
                    logger.warning("Timeout: No response received from Agent.")
                    # Create a new history list to ensure Gradio detects the update
                    new_history = history.copy()
                    new_history[-1]["content"] = "I'm sorry, I didn't receive a response in time."
                    yield new_history, ''
                    break
        else:
            try:
                # Non-blocking response retrieval from the thread-safe queue with timeout
                response = await asyncio.wait_for(
                    asyncio.to_thread(response_queue.get, True, 30),
                    timeout=30
                )
                logger.info(f"Retrieved response from queue: {response}")
                assistant_message += response
                # Create a new history list to ensure Gradio detects the update
                new_history = history.copy()
                new_history[-1]["content"] = assistant_message
                logger.info(f"Yielding updated history: {new_history}")
                yield new_history, ''
            except asyncio.TimeoutError:
                logger.warning("Timeout: No response received from Agent.")
                # Create a new history list to ensure Gradio detects the update
                new_history = history.copy()
                new_history[-1]["content"] = "I'm sorry, I didn't receive a response in time."
                yield new_history, ''

    except Exception as e:
        logger.error(f"Unexpected Error in bot function: {e}")
        # Create a new history list to ensure Gradio detects the update
        new_history = history.copy()
        new_history[-1]["content"] = "An unexpected error occurred."
        yield new_history, ''

clear_media()

Clears the uploaded media paths.

Source code in src/aeiva/demo/chat_gradio.py, lines 136–142
def clear_media():
    """
    Clears the uploaded media paths.
    """
    # Implement any necessary logic to clear media paths or data
    logger.info("Cleared uploaded media paths.")
    return ""

handle_upload(file)

Handles file uploads and delegates to specific handlers based on file type.

Parameters:

- file: Uploaded file object. Required.

Returns:

- str: Message indicating the upload status.

Source code in src/aeiva/demo/chat_gradio.py, lines 114–134
def handle_upload(file):
    """
    Handles file uploads and delegates to specific handlers based on file type.

    Args:
        file: Uploaded file object.

    Returns:
        str: Message indicating the upload status.
    """
    if file is None:
        return ""
    if file.type.startswith("image"):
        return handle_image_upload(file)
    elif file.type.startswith("video"):
        return handle_video_upload(file)
    elif file.type.startswith("audio"):
        return handle_audio_upload(file)
    else:
        logger.warning(f"Unsupported file type uploaded: {file.type}")
        return "Unsupported file type uploaded."

mm_chatbot

This module defines a multimodal chatbot demo with gradio.

@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-13

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

environment

environment

Environment

Bases: ABC

Abstract base class for an environment in which an intelligent agent operates.

Each environment provides context, defines interactions, and manages its own state. Subclasses should implement specific methods for different types of environments.

Attributes:

- config (EnvironmentConfig): Configuration settings for the environment.
- state (Any): Current state of the environment, initialized from the config.
- entities (List[Any]): Entities present within the environment.
- constraints (Dict[str, Any]): Rules or limitations for interactions in the environment.
- time (Optional[int]): Time progression within the environment, if enabled.

Source code in src/aeiva/environment/environment.py, lines 5–110
class Environment(ABC):
    """
    Abstract base class for an environment in which an intelligent agent operates.

    Each environment provides context, defines interactions, and manages its own state.
    Subclasses should implement specific methods for different types of environments.

    Attributes:
        config (EnvironmentConfig): Configuration settings for the environment.
        state (Any): Current state of the environment, initialized from the config.
        entities (List[Any]): Entities present within the environment.
        constraints (Dict[str, Any]): Rules or limitations for interactions in the environment.
        time (Optional[int]): Time progression within the environment, if enabled.
    """

    def __init__(self, config: EnvironmentConfig):
        """
        Initialize the environment with a given configuration.

        Args:
            config (EnvironmentConfig): Configuration settings for the environment.
        """
        self.config = config
        self.state = config.initial_state
        self.entities = config.entities
        self.constraints = config.constraints
        self.time = 0 if config.time_enabled else None
        self.setup()

    @abstractmethod
    def setup(self):
        """
        Set up the environment based on its configuration.
        Subclasses should define any initialization logic here.
        """
        pass

    @abstractmethod
    def reset(self):
        """
        Reset the environment to its initial state as defined by the configuration.
        """
        self.state = self.config.initial_state
        self.time = 0 if self.config.time_enabled else None

    @abstractmethod
    def step(self, actions: Dict[Any, Any]):
        """
        Advance the environment by one step based on actions taken by agents.

        Args:
            actions (Dict[Any, Any]): A dictionary of actions performed by agents.
        """
        pass

    @abstractmethod
    def observe(self, agent: Any) -> Any:
        """
        Provide observations to an agent based on the current state.

        Args:
            agent (Any): The agent requesting observation.

        Returns:
            Any: Observation data formatted according to the agent's perception capabilities.
        """
        pass

    @abstractmethod
    def act(self, action: Any, target: Optional[Any] = None):
        """
        Execute an action in the environment, potentially modifying its state.

        Args:
            action (Any): The action to be executed.
            target (Optional[Any]): Target entity for the action, if applicable.
        """
        pass

    def render(self):
        """
        Visualize or output the environment's current state. Optional for subclasses.
        """
        print(f"Environment State: {self.state}")

    def get_context(self) -> Any:
        """
        Retrieve relevant context information from the environment, useful for agent processing.

        Returns:
            Any: Contextual data or state relevant to the agent's tasks.
        """
        return self.state

    def close(self):
        """
        Clean up any resources tied to the environment when it's no longer needed.
        """
        print("Closing environment and releasing resources.")

    def __repr__(self) -> str:
        return (f"Environment(type={self.config.environment_type}, "
                f"state={self.state}, "
                f"entities={self.entities}, "
                f"time={self.time}, "
                f"constraints={self.constraints})")
__init__(config)

Initialize the environment with a given configuration.

Parameters:

- config (EnvironmentConfig): Configuration settings for the environment. Required.
Source code in src/aeiva/environment/environment.py, lines 20–32
def __init__(self, config: EnvironmentConfig):
    """
    Initialize the environment with a given configuration.

    Args:
        config (EnvironmentConfig): Configuration settings for the environment.
    """
    self.config = config
    self.state = config.initial_state
    self.entities = config.entities
    self.constraints = config.constraints
    self.time = 0 if config.time_enabled else None
    self.setup()
act(action, target=None) abstractmethod

Execute an action in the environment, potentially modifying its state.

Parameters:

- action (Any): The action to be executed. Required.
- target (Optional[Any]): Target entity for the action, if applicable. Defaults to None.
Source code in src/aeiva/environment/environment.py, lines 73–82
@abstractmethod
def act(self, action: Any, target: Optional[Any] = None):
    """
    Execute an action in the environment, potentially modifying its state.

    Args:
        action (Any): The action to be executed.
        target (Optional[Any]): Target entity for the action, if applicable.
    """
    pass
close()

Clean up any resources tied to the environment when it's no longer needed.

Source code in src/aeiva/environment/environment.py
def close(self):
    """
    Clean up any resources tied to the environment when it's no longer needed.
    """
    print("Closing environment and releasing resources.")
get_context()

Retrieve relevant context information from the environment, useful for agent processing.

Returns:
  • Any: Contextual data or state relevant to the agent's tasks.

Source code in src/aeiva/environment/environment.py
def get_context(self) -> Any:
    """
    Retrieve relevant context information from the environment, useful for agent processing.

    Returns:
        Any: Contextual data or state relevant to the agent's tasks.
    """
    return self.state
observe(agent) abstractmethod

Provide observations to an agent based on the current state.

Parameters:
  • agent (Any): The agent requesting observation. Required.

Returns:
  • Any: Observation data formatted according to the agent's perception capabilities.

Source code in src/aeiva/environment/environment.py
@abstractmethod
def observe(self, agent: Any) -> Any:
    """
    Provide observations to an agent based on the current state.

    Args:
        agent (Any): The agent requesting observation.

    Returns:
        Any: Observation data formatted according to the agent's perception capabilities.
    """
    pass
render()

Visualize or output the environment's current state. Optional for subclasses.

Source code in src/aeiva/environment/environment.py
def render(self):
    """
    Visualize or output the environment's current state. Optional for subclasses.
    """
    print(f"Environment State: {self.state}")
reset() abstractmethod

Reset the environment to its initial state as defined by the configuration.

Source code in src/aeiva/environment/environment.py
@abstractmethod
def reset(self):
    """
    Reset the environment to its initial state as defined by the configuration.
    """
    self.state = self.config.initial_state
    self.time = 0 if self.config.time_enabled else None
setup() abstractmethod

Set up the environment based on its configuration. Subclasses should define any initialization logic here.

Source code in src/aeiva/environment/environment.py
@abstractmethod
def setup(self):
    """
    Set up the environment based on its configuration.
    Subclasses should define any initialization logic here.
    """
    pass
step(actions) abstractmethod

Advance the environment by one step based on actions taken by agents.

Parameters:
  • actions (Dict[Any, Any]): A dictionary of actions performed by agents. Required.
Source code in src/aeiva/environment/environment.py
@abstractmethod
def step(self, actions: Dict[Any, Any]):
    """
    Advance the environment by one step based on actions taken by agents.

    Args:
        actions (Dict[Any, Any]): A dictionary of actions performed by agents.
    """
    pass

environment_config

EnvironmentConfig dataclass

Bases: BaseConfig

Configuration class for initializing an environment.

Attributes:
  • environment_type (str): Type of the environment; see the EnvironmentType class.
  • initial_state (Optional[Any]): Optional initial state of the environment.
  • constraints (Dict[str, Any]): Rules or constraints governing the environment.
  • entities (List[Any]): Entities present within the environment.
  • time_enabled (bool): Whether the environment tracks time progression.

Source code in src/aeiva/environment/environment_config.py
@dataclass
class EnvironmentConfig(BaseConfig):
    """
    Configuration class for initializing an environment.

    Attributes:
        environment_type (str): Type of the environment, see EnvironmentType class.
        initial_state (Optional[Any]): Optional initial state of the environment.
        constraints (Dict[str, Any]): Rules or constraints governing the environment.
        entities (List[Any]): Entities present within the environment.
        time_enabled (bool): Whether the environment tracks time progression.
    """

    environment_type: str = field(
        metadata={"help": "Type of the environment (e.g., 'user', 'document', 'game')."}
    )
    initial_state: Optional[Any] = field(
        default=None,
        metadata={"help": "Optional initial state of the environment."}
    )
    constraints: Dict[str, Any] = field(
        default_factory=dict,
        metadata={"help": "Rules or constraints for the environment."}
    )
    entities: List[Any] = field(
        default_factory=list,
        metadata={"help": "Entities within the environment."}
    )
    time_enabled: bool = field(
        default=False,
        metadata={"help": "Flag to enable time progression."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Perform any necessary validation
        if not self.environment_type:
            raise ValueError("Environment type must be provided.")
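A minimal construction sketch (the import path is assumed from the source layout; field values are illustrative):

from aeiva.environment.environment_config import EnvironmentConfig  # path assumed

config = EnvironmentConfig(
    environment_type="Interactive",
    initial_state={"messages": []},
    time_enabled=True,
)
# An empty environment_type raises ValueError in __post_init__.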

environment_type

EnvironmentType

A class to hold constants for various environment types, organized by broad categories to maximize generality while supporting diverse use cases.

Categories:
  • Interaction-Based: Environments with user or agent interaction.
  • Digital: Environments involving digital interfaces, applications, or software systems.
  • Data-Based: Static or dynamic data collections or document repositories.
  • Virtual/Simulated: Simulated, spatial, or immersive virtual environments.
  • World-Level: Comprehensive real or virtual world environments.
Source code in src/aeiva/environment/environment_type.py
class EnvironmentType:
    """
    A class to hold constants for various environment types, organized by broad categories
    to maximize generality while supporting diverse use cases.

    Categories:
        - Interaction-Based: Environments with user or agent interaction.
        - Digital: Environments involving digital interfaces, applications, or software systems.
        - Data-Based: Static or dynamic data collections or document repositories.
        - Virtual/Simulated: Simulated, spatial, or immersive virtual environments.
        - World-Level: Comprehensive real or virtual world environments.
    """

    # Interaction-Based Environments
    INTERACTIVE = "Interactive"  # Environments involving user or multi-agent interaction.

    # Digital Environments
    DIGITAL_ENVIRONMENT = "Digital Environment"  # Digital workspaces, applications, OS, or software systems.

    # Data-Based Environments
    DATA_REPOSITORY = "Data Repository"  # Static datasets, dynamic data streams, or document repositories (e.g., knowledge bases).

    # Virtual/Simulated Environments
    VIRTUAL_ENVIRONMENT = "Virtual Environment"  # Simulated or immersive 3D spaces, including games and VR.

    # World-Level Environments
    FULL_WORLD = "Full World"  # Comprehensive virtual or real-world environment.

    # Meta/Complex Environments
    HYBRID_ENVIRONMENT = "Hybrid Environment"  # Combination of multiple types.

    # Custom environment type for unique or unspecified cases.
    CUSTOM = "Custom"
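These constants are intended to be passed wherever an environment type string is expected, for instance in an EnvironmentConfig. A small sketch (import paths assumed from the source layout):

from aeiva.environment.environment_config import EnvironmentConfig  # paths assumed
from aeiva.environment.environment_type import EnvironmentType

config = EnvironmentConfig(environment_type=EnvironmentType.VIRTUAL_ENVIRONMENT)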

event

event

Event dataclass

Represents an event in the event bus system.

Attributes:
  • name (str): The name of the event.
  • payload (Any): The data associated with the event.
  • timestamp (datetime): The time the event was created.
  • priority (int): The priority of the event.

Source code in src/aeiva/event/event.py
@dataclass
class Event:
    """
    Represents an event in the event bus system.

    Attributes:
        name (str): The name of the event.
        payload (Any): The data associated with the event.
        timestamp (datetime): The time the event was created.
        priority (int): The priority of the event.
    """
    name: str
    payload: Any = None
    timestamp: datetime = field(default_factory=datetime.utcnow)
    priority: int = 0
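Creating an event is a plain dataclass instantiation. A minimal sketch (the import path is assumed from the source layout):

from aeiva.event.event import Event  # path assumed

evt = Event(name="user.login", payload={"user_id": 42}, priority=5)
print(evt.name, evt.priority)  # timestamp defaults to the UTC creation time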

event_bus

EventBus

An asynchronous event bus for publishing and subscribing to events.

Features:
  • Subscribers can use wildcard patterns to subscribe to multiple events.
  • Subscribers can cancel event propagation.
  • Subscribers can be set to auto-unsubscribe after one call.
  • Event-level prioritization in the queue.
  • Customizable error handling.
  • Logging for key actions.
  • emit, emit_after, and emit_only methods for flexible event emission.

Source code in src/aeiva/event/event_bus.py
class EventBus:
    """
    An asynchronous event bus for publishing and subscribing to events.

    Features:
    - Subscribers can use wildcard patterns to subscribe to multiple events.
    - Subscribers can cancel event propagation.
    - Subscribers can be set to auto-unsubscribe after one call.
    - Event-level prioritization in the queue.
    - Customizable error handling.
    - Logging for key actions.
    - emit, emit_after, and emit_only methods for flexible event emission.
    """

    def __init__(self):
        """
        Initializes the event bus.
        """
        self._subscribers: List[Dict] = []  # List of subscriber dictionaries
        self._event_queue = asyncio.PriorityQueue()
        self._processing_task: Optional[asyncio.Task] = None
        self._event_counter = 0  # Counter to maintain order of events with same priority
        self.loop = None

    def subscribe(
        self,
        event_pattern: str,
        callback: Callable[[Event], Any],
        *,
        priority: int = 0,
        once: bool = False
    ):
        """
        Subscribes a callback function to events matching a pattern.

        Args:
            event_pattern (str): The event name or pattern to subscribe to.
            callback (Callable[[Event], Any]): The callback function.
            priority (int, optional): Priority of the callback.
            once (bool, optional): If True, unsubscribe after one call.
        """
        subscriber = {
            'pattern': re.compile(event_pattern.replace('*', '.*')),
            'callback': callback,
            'priority': priority,
            'once': once
        }
        self._subscribers.append(subscriber)
        logger.info(f"Subscribed '{callback.__name__}' to pattern '{event_pattern}' with priority {priority}.")

    def unsubscribe(self, callback: Callable[[Event], Any]):
        """
        Unsubscribes a callback function from all events.

        Args:
            callback (Callable[[Event], Any]): The callback function to remove.
        """
        self._subscribers = [
            sub for sub in self._subscribers
            if sub['callback'] != callback
        ]
        logger.info(f"Unsubscribed '{callback.__name__}' from all events.")

    async def publish(self, event: Event, only: Union[str, List[str]] = None):
        """
        Publishes an event to the event bus.

        Args:
            event (Event): The event to publish.
            only (str or List[str], optional): Names of specific subscribers to notify.
        """
        self._event_counter += 1
        # Use a tuple of (priority, counter) to ensure proper ordering
        await self._event_queue.put((event.priority * -1, self._event_counter, event, only))
        logger.info(f"Published event '{event.name}' with priority {event.priority}.")

    async def _process_events(self):
        """
        Internal coroutine that processes events from the queue and dispatches them to subscribers.
        """
        while True:
            try:
                _, _, event, only = await self._event_queue.get()
                logger.info(f"Processing event '{event.name}'.")
                await self._dispatch_event(event, only)
                self._event_queue.task_done()
            except asyncio.CancelledError:
                # Exit the loop gracefully
                break
            except Exception as e:
                logger.error(f"Error processing event: {e}")
                self._event_queue.task_done()

    async def _dispatch_event(self, event: Event, only: Union[str, List[str]] = None):
        """
        Dispatches an event to the appropriate subscribers.

        Args:
            event (Event): The event to dispatch.
            only (str or List[str], optional): Names of specific subscribers to notify.
        """
        subscribers = sorted(
            [
                sub for sub in self._subscribers
                if sub['pattern'].fullmatch(event.name)
                and (only is None or sub['callback'].__name__ in (only if isinstance(only, list) else [only]))
            ],
            key=lambda x: x['priority'],
            reverse=True
        )
        for subscriber in subscribers:
            callback = subscriber['callback']
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(event)
                else:
                    await asyncio.get_event_loop().run_in_executor(None, callback, event)
            except EventCancelled:
                logger.info(f"Event '{event.name}' cancelled by '{callback.__name__}'.")
                break  # Stop further propagation
            except Exception as e:
                logger.error(f"Error in callback '{callback.__name__}' for event '{event.name}': {e}")
                self._handle_callback_exception(e, callback, event)
            finally:
                if subscriber.get('once'):
                    self.unsubscribe(callback)

    def _handle_callback_exception(self, exception, callback, event):
        """
        Handle exceptions raised by subscriber callbacks.

        Args:
            exception (Exception): The exception raised.
            callback (Callable): The subscriber callback.
            event (Event): The event being processed.
        """
        # Default behavior is to log the exception.
        pass  # Can be customized as needed.

    def start(self):
        """
        Starts the event bus processing loop.
        """
        if self._processing_task is None:
            self.loop = asyncio.get_running_loop()
            self._processing_task = asyncio.create_task(self._process_events())
            logger.info("Event bus started.")

    def stop(self):
        """
        Stops the event bus processing loop.
        """
        if self._processing_task:
            self._processing_task.cancel()
            logger.info("Event bus stopped.")

    def on(self, event_pattern: str, priority: int = 0, once: bool = False):
        """
        Decorator for subscribing a function to events matching a pattern.

        Usage:
            @event_bus.on('event.*', priority=10)
            async def handler(event):
                ...

        Args:
            event_pattern (str): The event name or pattern to subscribe to.
            priority (int, optional): Priority of the callback.
            once (bool, optional): If True, unsubscribe after one call.

        Returns:
            Callable: The decorator function.
        """
        def decorator(callback: Callable[[Event], Any]):
            self.subscribe(event_pattern, callback, priority=priority, once=once)
            return callback
        return decorator

    def emit_after(self, event_name: str, priority: int = 0):
        """
        Decorator that emits an event after the decorated function is called.

        Usage:
            @event_bus.emit_after('event_name')
            def some_function():
                ...

        Args:
            event_name (str): The name of the event to emit after function execution.
            priority (int, optional): The priority of the event.

        Returns:
            Callable: The decorator function.
        """
        def decorator(func: Callable):
            if asyncio.iscoroutinefunction(func):
                @wraps(func)
                async def async_wrapper(*args, **kwargs):
                    result = await func(*args, **kwargs)
                    await self.emit(event_name, priority=priority)
                    return result
                return async_wrapper
            else:
                @wraps(func)
                def sync_wrapper(*args, **kwargs):
                    result = func(*args, **kwargs)
                    asyncio.create_task(self.emit(event_name, priority=priority))
                    return result
                return sync_wrapper
        return decorator

    async def emit(self, event_name: str, payload: Any = None, priority: int = 0):
        """
        Emits an event to all matching subscribers.

        Args:
            event_name (str): The name of the event to emit.
            payload (Any, optional): The payload of the event.
            priority (int, optional): The priority of the event.
        """
        await self.publish(Event(name=event_name, payload=payload, priority=priority))

    async def emit_only(self, event_name: str, subscriber_names: Union[str, List[str]], payload: Any = None, priority: int = 0):
        """
        Emits an event only to specified subscribers.

        Args:
            event_name (str): The name of the event to emit.
            subscriber_names (str or List[str]): The name(s) of subscribers to notify.
            payload (Any, optional): The payload of the event.
            priority (int, optional): The priority of the event.
        """
        await self.publish(Event(name=event_name, payload=payload, priority=priority), only=subscriber_names)

    async def wait_until_all_events_processed(self):
        """
        Waits until all events in the queue have been processed.
        """
        await self._event_queue.join()
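Putting the pieces together, a minimal end-to-end sketch: start the bus inside a running event loop, register a wildcard handler with the on decorator, emit an event, and drain the queue before stopping. The event names are illustrative; the import path is assumed from the source layout.

import asyncio

from aeiva.event.event_bus import EventBus  # path assumed

async def main():
    bus = EventBus()
    bus.start()  # requires a running event loop (uses asyncio.get_running_loop)

    @bus.on("user.*")
    async def on_user_event(event):
        print(f"got {event.name}: {event.payload}")

    await bus.emit("user.login", payload={"user_id": 42})
    await bus.wait_until_all_events_processed()  # block until the queue drains
    bus.stop()

asyncio.run(main())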
__init__()

Initializes the event bus.

Source code in src/aeiva/event/event_bus.py
def __init__(self):
    """
    Initializes the event bus.
    """
    self._subscribers: List[Dict] = []  # List of subscriber dictionaries
    self._event_queue = asyncio.PriorityQueue()
    self._processing_task: Optional[asyncio.Task] = None
    self._event_counter = 0  # Counter to maintain order of events with same priority
    self.loop = None
emit(event_name, payload=None, priority=0) async

Emits an event to all matching subscribers.

Parameters:
  • event_name (str): The name of the event to emit. Required.
  • payload (Any): The payload of the event. Default: None.
  • priority (int): The priority of the event. Default: 0.
Source code in src/aeiva/event/event_bus.py
async def emit(self, event_name: str, payload: Any = None, priority: int = 0):
    """
    Emits an event to all matching subscribers.

    Args:
        event_name (str): The name of the event to emit.
        payload (Any, optional): The payload of the event.
        priority (int, optional): The priority of the event.
    """
    await self.publish(Event(name=event_name, payload=payload, priority=priority))
emit_after(event_name, priority=0)

Decorator that emits an event after the decorated function is called.

Usage:

    @event_bus.emit_after('event_name')
    def some_function():
        ...

Parameters:
  • event_name (str): The name of the event to emit after function execution. Required.
  • priority (int): The priority of the event. Default: 0.

Returns:
  • Callable: The decorator function.

Source code in src/aeiva/event/event_bus.py
def emit_after(self, event_name: str, priority: int = 0):
    """
    Decorator that emits an event after the decorated function is called.

    Usage:
        @event_bus.emit_after('event_name')
        def some_function():
            ...

    Args:
        event_name (str): The name of the event to emit after function execution.
        priority (int, optional): The priority of the event.

    Returns:
        Callable: The decorator function.
    """
    def decorator(func: Callable):
        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                result = await func(*args, **kwargs)
                await self.emit(event_name, priority=priority)
                return result
            return async_wrapper
        else:
            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                result = func(*args, **kwargs)
                asyncio.create_task(self.emit(event_name, priority=priority))
                return result
            return sync_wrapper
    return decorator
emit_only(event_name, subscriber_names, payload=None, priority=0) async

Emits an event only to specified subscribers.

Parameters:
  • event_name (str): The name of the event to emit. Required.
  • subscriber_names (str or List[str]): The name(s) of subscribers to notify. Required.
  • payload (Any): The payload of the event. Default: None.
  • priority (int): The priority of the event. Default: 0.
Source code in src/aeiva/event/event_bus.py
async def emit_only(self, event_name: str, subscriber_names: Union[str, List[str]], payload: Any = None, priority: int = 0):
    """
    Emits an event only to specified subscribers.

    Args:
        event_name (str): The name of the event to emit.
        subscriber_names (str or List[str]): The name(s) of subscribers to notify.
        payload (Any, optional): The payload of the event.
        priority (int, optional): The priority of the event.
    """
    await self.publish(Event(name=event_name, payload=payload, priority=priority), only=subscriber_names)
on(event_pattern, priority=0, once=False)

Decorator for subscribing a function to events matching a pattern.

Usage:

    @event_bus.on('event.*', priority=10)
    async def handler(event):
        ...

Parameters:
  • event_pattern (str): The event name or pattern to subscribe to. Required.
  • priority (int): Priority of the callback. Default: 0.
  • once (bool): If True, unsubscribe after one call. Default: False.

Returns:
  • Callable: The decorator function.

Source code in src/aeiva/event/event_bus.py
def on(self, event_pattern: str, priority: int = 0, once: bool = False):
    """
    Decorator for subscribing a function to events matching a pattern.

    Usage:
        @event_bus.on('event.*', priority=10)
        async def handler(event):
            ...

    Args:
        event_pattern (str): The event name or pattern to subscribe to.
        priority (int, optional): Priority of the callback.
        once (bool, optional): If True, unsubscribe after one call.

    Returns:
        Callable: The decorator function.
    """
    def decorator(callback: Callable[[Event], Any]):
        self.subscribe(event_pattern, callback, priority=priority, once=once)
        return callback
    return decorator
publish(event, only=None) async

Publishes an event to the event bus.

Parameters:
  • event (Event): The event to publish. Required.
  • only (str or List[str], optional): Names of specific subscribers to notify. Default: None.
Source code in src/aeiva/event/event_bus.py
async def publish(self, event: Event, only: Union[str, List[str]] = None):
    """
    Publishes an event to the event bus.

    Args:
        event (Event): The event to publish.
        only (str or List[str], optional): Names of specific subscribers to notify.
    """
    self._event_counter += 1
    # Use a tuple of (priority, counter) to ensure proper ordering
    await self._event_queue.put((event.priority * -1, self._event_counter, event, only))
    logger.info(f"Published event '{event.name}' with priority {event.priority}.")
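The (event.priority * -1, counter) tuple deserves a note: asyncio.PriorityQueue pops the smallest tuple first, so negating the priority makes higher-priority events dequeue first, while the monotonically increasing counter keeps same-priority events in FIFO order. A standalone illustration of that ordering:

import asyncio

async def demo():
    q = asyncio.PriorityQueue()
    await q.put((-10, 1, "high priority"))
    await q.put((0, 2, "normal, published first"))
    await q.put((0, 3, "normal, published second"))
    print(await q.get())  # (-10, 1, 'high priority')
    print(await q.get())  # (0, 2, 'normal, published first')
    print(await q.get())  # (0, 3, 'normal, published second')

asyncio.run(demo())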
start()

Starts the event bus processing loop.

Source code in src/aeiva/event/event_bus.py
def start(self):
    """
    Starts the event bus processing loop.
    """
    if self._processing_task is None:
        self.loop = asyncio.get_running_loop()
        self._processing_task = asyncio.create_task(self._process_events())
        logger.info("Event bus started.")
stop()

Stops the event bus processing loop.

Source code in src/aeiva/event/event_bus.py
def stop(self):
    """
    Stops the event bus processing loop.
    """
    if self._processing_task:
        self._processing_task.cancel()
        logger.info("Event bus stopped.")
subscribe(event_pattern, callback, *, priority=0, once=False)

Subscribes a callback function to events matching a pattern.

Parameters:
  • event_pattern (str): The event name or pattern to subscribe to. Required.
  • callback (Callable[[Event], Any]): The callback function. Required.
  • priority (int): Priority of the callback. Default: 0.
  • once (bool): If True, unsubscribe after one call. Default: False.
Source code in src/aeiva/event/event_bus.py
def subscribe(
    self,
    event_pattern: str,
    callback: Callable[[Event], Any],
    *,
    priority: int = 0,
    once: bool = False
):
    """
    Subscribes a callback function to events matching a pattern.

    Args:
        event_pattern (str): The event name or pattern to subscribe to.
        callback (Callable[[Event], Any]): The callback function.
        priority (int, optional): Priority of the callback.
        once (bool, optional): If True, unsubscribe after one call.
    """
    subscriber = {
        'pattern': re.compile(event_pattern.replace('*', '.*')),
        'callback': callback,
        'priority': priority,
        'once': once
    }
    self._subscribers.append(subscriber)
    logger.info(f"Subscribed '{callback.__name__}' to pattern '{event_pattern}' with priority {priority}.")
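Since the pattern is built by replacing '*' with '.*' and matched with fullmatch, note that literal dots are not escaped, so '.' in a pattern also matches any single character. A quick illustration of what 'user.*' actually matches under this scheme:

import re

pattern = re.compile('user.*'.replace('*', '.*'))  # compiles to 'user..*'
print(bool(pattern.fullmatch('user.login')))     # True
print(bool(pattern.fullmatch('userXlogin')))     # also True: the '.' is unescaped
print(bool(pattern.fullmatch('account.login')))  # False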
unsubscribe(callback)

Unsubscribes a callback function from all events.

Parameters:
  • callback (Callable[[Event], Any]): The callback function to remove. Required.
Source code in src/aeiva/event/event_bus.py
def unsubscribe(self, callback: Callable[[Event], Any]):
    """
    Unsubscribes a callback function from all events.

    Args:
        callback (Callable[[Event], Any]): The callback function to remove.
    """
    self._subscribers = [
        sub for sub in self._subscribers
        if sub['callback'] != callback
    ]
    logger.info(f"Unsubscribed '{callback.__name__}' from all events.")
wait_until_all_events_processed() async

Waits until all events in the queue have been processed.

Source code in src/aeiva/event/event_bus.py
async def wait_until_all_events_processed(self):
    """
    Waits until all events in the queue have been processed.
    """
    await self._event_queue.join()

EventCancelled

Bases: Exception

Exception to indicate that an event has been cancelled.

Source code in src/aeiva/event/event_bus.py
class EventCancelled(Exception):
    """Exception to indicate that an event has been cancelled."""
    pass
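A sketch of cancellation in practice: because subscribers are dispatched in descending priority order, a high-priority handler that raises EventCancelled stops lower-priority handlers from running. The handlers and event names are illustrative; the import path is assumed.

import asyncio

from aeiva.event.event_bus import EventBus, EventCancelled  # path assumed

async def main():
    bus = EventBus()
    bus.start()

    @bus.on("order.created", priority=100)
    async def validate(event):
        if event.payload.get("amount", 0) <= 0:
            raise EventCancelled()  # stops propagation to lower-priority handlers

    @bus.on("order.created", priority=0)
    async def fulfill(event):
        print("fulfilling", event.payload)  # skipped when validate cancels

    await bus.emit("order.created", payload={"amount": 0})
    await bus.wait_until_all_events_processed()
    bus.stop()

asyncio.run(main())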

hypergraph

exceptions

HypergraphError

Bases: Exception

Custom exception class for Hypergraph-related errors.

Source code in src/aeiva/hypergraph/exceptions.py
class HypergraphError(Exception):
    """
    Custom exception class for Hypergraph-related errors.
    """
    def __init__(self, message: str = "An error occurred in the Hypergraph module."):
        super().__init__(message)

hyperedge

HyperEdge

Represents a hyperedge in the hypergraph, encapsulating its properties and connected nodes.

Source code in src/aeiva/hypergraph/hyperedge.py
class HyperEdge:
    """
    Represents a hyperedge in the hypergraph, encapsulating its properties and connected nodes.
    """

    def __init__(
        self,
        id: Any,
        nodes: Optional[Iterable[Any]] = None,
        properties: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Initializes a HyperEdge.

        Parameters:
            id: Unique identifier for the hyperedge.
            nodes: (Optional) Iterable of node identifiers connected by the hyperedge.
            properties: (Optional) Dictionary of properties.
        """
        self.id: Any = id
        self.nodes: Set[Any] = set(nodes) if nodes else set()
        self.properties: Dict[str, Any] = properties.copy() if properties else {}

    def add_node(self, node_id: Any) -> None:
        """
        Adds a node to the hyperedge.

        Parameters:
            node_id: Identifier of the node to add.
        """
        self.nodes.add(node_id)

    def remove_node(self, node_id: Any) -> None:
        """
        Removes a node from the hyperedge.

        Parameters:
            node_id: Identifier of the node to remove.
        """
        if node_id in self.nodes:
            self.nodes.remove(node_id)
        else:
            raise HypergraphError(f"Node '{node_id}' not found in HyperEdge '{self.id}'.")

    def add_property(self, key: str, value: Any) -> None:
        """
        Adds or updates a property of the hyperedge.

        Parameters:
            key: Property name.
            value: Property value.
        """
        self.properties[key] = value

    def get_property(self, key: str) -> Any:
        """
        Retrieves a property of the hyperedge.

        Parameters:
            key: Property name.

        Returns:
            The value of the property.

        Raises:
            HypergraphError: If the property does not exist.
        """
        if key in self.properties:
            return self.properties[key]
        else:
            raise HypergraphError(f"Property '{key}' does not exist for HyperEdge '{self.id}'.")

    def remove_property(self, key: str) -> None:
        """
        Removes a property from the hyperedge.

        Parameters:
            key: Property name.

        Raises:
            HypergraphError: If the property does not exist.
        """
        if key in self.properties:
            del self.properties[key]
        else:
            raise HypergraphError(f"Property '{key}' does not exist for HyperEdge '{self.id}'.")

    def to_dict(self):
        return {
            "id": self.id,
            "nodes": self.nodes,
            "properties": self.properties
        }
__init__(id, nodes=None, properties=None)

Initializes a HyperEdge.

Parameters:
  • id (Any): Unique identifier for the hyperedge. Required.
  • nodes (Optional[Iterable[Any]]): Iterable of node identifiers connected by the hyperedge. Default: None.
  • properties (Optional[Dict[str, Any]]): Dictionary of properties. Default: None.
Source code in src/aeiva/hypergraph/hyperedge.py
def __init__(
    self,
    id: Any,
    nodes: Optional[Iterable[Any]] = None,
    properties: Optional[Dict[str, Any]] = None
) -> None:
    """
    Initializes a HyperEdge.

    Parameters:
        id: Unique identifier for the hyperedge.
        nodes: (Optional) Iterable of node identifiers connected by the hyperedge.
        properties: (Optional) Dictionary of properties.
    """
    self.id: Any = id
    self.nodes: Set[Any] = set(nodes) if nodes else set()
    self.properties: Dict[str, Any] = properties.copy() if properties else {}
add_node(node_id)

Adds a node to the hyperedge.

Parameters:
  • node_id (Any): Identifier of the node to add. Required.
Source code in src/aeiva/hypergraph/hyperedge.py
def add_node(self, node_id: Any) -> None:
    """
    Adds a node to the hyperedge.

    Parameters:
        node_id: Identifier of the node to add.
    """
    self.nodes.add(node_id)
add_property(key, value)

Adds or updates a property of the hyperedge.

Parameters:
  • key (str): Property name. Required.
  • value (Any): Property value. Required.
Source code in src/aeiva/hypergraph/hyperedge.py
def add_property(self, key: str, value: Any) -> None:
    """
    Adds or updates a property of the hyperedge.

    Parameters:
        key: Property name.
        value: Property value.
    """
    self.properties[key] = value
get_property(key)

Retrieves a property of the hyperedge.

Parameters:
  • key (str): Property name. Required.

Returns:
  • Any: The value of the property.

Raises:
  • HypergraphError: If the property does not exist.

Source code in src/aeiva/hypergraph/hyperedge.py
def get_property(self, key: str) -> Any:
    """
    Retrieves a property of the hyperedge.

    Parameters:
        key: Property name.

    Returns:
        The value of the property.

    Raises:
        HypergraphError: If the property does not exist.
    """
    if key in self.properties:
        return self.properties[key]
    else:
        raise HypergraphError(f"Property '{key}' does not exist for HyperEdge '{self.id}'.")
remove_node(node_id)

Removes a node from the hyperedge.

Parameters:
  • node_id (Any): Identifier of the node to remove. Required.
Source code in src/aeiva/hypergraph/hyperedge.py
def remove_node(self, node_id: Any) -> None:
    """
    Removes a node from the hyperedge.

    Parameters:
        node_id: Identifier of the node to remove.
    """
    if node_id in self.nodes:
        self.nodes.remove(node_id)
    else:
        raise HypergraphError(f"Node '{node_id}' not found in HyperEdge '{self.id}'.")
remove_property(key)

Removes a property from the hyperedge.

Parameters:
  • key (str): Property name. Required.

Raises:
  • HypergraphError: If the property does not exist.

Source code in src/aeiva/hypergraph/hyperedge.py
def remove_property(self, key: str) -> None:
    """
    Removes a property from the hyperedge.

    Parameters:
        key: Property name.

    Raises:
        HypergraphError: If the property does not exist.
    """
    if key in self.properties:
        del self.properties[key]
    else:
        raise HypergraphError(f"Property '{key}' does not exist for HyperEdge '{self.id}'.")
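A minimal usage sketch (the import path is assumed from the source layout):

from aeiva.hypergraph.hyperedge import HyperEdge  # path assumed

edge = HyperEdge(id="e1", nodes=["a", "b"], properties={"weight": 0.5})
edge.add_node("c")
edge.add_property("label", "triad")
print(edge.to_dict())  # note: 'nodes' is returned as a set
# edge.get_property("missing") would raise HypergraphError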

hypergraph

Hypergraph

A simplified Hypergraph class using dictionaries and NetworkX for management.

Parameters:
  • hyperedges (Dict[Any, Dict[str, Any]]): A dictionary where keys are hyperedge identifiers and values are dictionaries containing:
      - 'nodes': Iterable of node identifiers connected by the hyperedge.
      - 'properties': (Optional) Dictionary of properties for the hyperedge.
  • node_properties (Optional[Dict[Any, Dict[str, Any]]], default None): A dictionary where keys are node identifiers and values are dictionaries of node properties.
  • hyperedge_properties (Optional[Dict[Any, Dict[str, Any]]], default None): A dictionary where keys are hyperedge identifiers and values are dictionaries of hyperedge properties.
  • name (Optional[str], default None): Name assigned to the hypergraph.
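A construction sketch (import path assumed; note that nodes() and len() read from node_properties, so supply an entry for every node you want reported):

from aeiva.hypergraph.hypergraph import Hypergraph  # path assumed

hg = Hypergraph(
    hyperedges={
        "e1": {"nodes": ["a", "b"], "properties": {"kind": "friend"}},
        "e2": {"nodes": ["b", "c"]},
    },
    node_properties={"a": {}, "b": {}, "c": {}},
    name="demo",
)
print(hg.edges())      # ['e1', 'e2']
print(hg["b"])         # neighbors of 'b' via shared hyperedges: {'a', 'c'}
print(hg.dual().name)  # 'demo_dual'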

Source code in src/aeiva/hypergraph/hypergraph.py
class Hypergraph:
    """
    A simplified Hypergraph class using dictionaries and NetworkX for management.

    Parameters
    ----------
    hyperedges : Dict[Any, Dict[str, Any]]
        A dictionary where keys are hyperedge identifiers and values are dictionaries containing:
            - 'nodes': Iterable of node identifiers connected by the hyperedge.
            - 'properties': (Optional) Dictionary of properties for the hyperedge.

    node_properties : Optional[Dict[Any, Dict[str, Any]]] = None
        A dictionary where keys are node identifiers and values are dictionaries of node properties.

    hyperedge_properties : Optional[Dict[Any, Dict[str, Any]]] = None
        A dictionary where keys are hyperedge identifiers and values are dictionaries of hyperedge properties.

    name : Optional[str] = None
        Name assigned to the hypergraph.
    """

    def __init__(
        self,
        hyperedges: Dict[Any, Dict[str, Any]],
        node_properties: Optional[Dict[Any, Dict[str, Any]]] = None,
        hyperedge_properties: Optional[Dict[Any, Dict[str, Any]]] = None,
        name: Optional[str] = None
    ):
        self.name = name
        self.graph = nx.Graph()
        self.bipartite_nodes: Set[Any] = set()

        # Initialize node and hyperedge properties using deep copies to ensure full duplication
        self.node_properties: Dict[Any, Dict[str, Any]] = copy.deepcopy(node_properties) if node_properties else {}
        self.hyperedge_properties: Dict[Any, Dict[str, Any]] = copy.deepcopy(hyperedge_properties) if hyperedge_properties else {}

        # Add hyperedges and their connections to nodes
        self.hyperedges: Dict[Any, HyperEdge] = {}
        for he_id, he_data in hyperedges.items():
            nodes = he_data.get('nodes', [])
            properties = he_data.get('properties', {})
            hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)
            self.hyperedges[he_id] = hyperedge

            # Add hyperedge to bipartite graph with properties
            self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties.get(he_id, {}))
            self.bipartite_nodes.add(he_id)

            # Add edges between hyperedge and nodes with node properties
            for node in hyperedge.nodes:
                if node not in self.graph:
                    self.graph.add_node(node, bipartite='node', **self.node_properties.get(node, {}))
                self.graph.add_edge(he_id, node)

    def dual(self, name: Optional[str] = None) -> "Hypergraph":
        """
        Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.

        Parameters
        ----------
        name : Optional[str], default=None
            Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.

        Returns
        -------
        Hypergraph
            A new Hypergraph instance representing the dual of the current hypergraph.
        """
        # Initialize dual hyperedges, which will correspond to original nodes
        dual_hyperedges = {}

        # Invert the node-hyperedge structure
        for he_id, hyperedge in self.hyperedges.items():
            for node in hyperedge.nodes:
                # Each original node becomes a hyperedge in the dual
                if node not in dual_hyperedges:
                    dual_hyperedges[node] = {'nodes': [], 'properties': self.node_properties.get(node, {})}
                # The new hyperedge (original node) connects to the original hyperedge id as a "node"
                dual_hyperedges[node]['nodes'].append(he_id)

        # Define node properties in the dual as the original hyperedge properties
        dual_node_properties = {he_id: self.hyperedge_properties.get(he_id, {}) for he_id in self.hyperedges}

        # Create and return the dual Hypergraph
        return Hypergraph(
            hyperedges=dual_hyperedges,
            node_properties=dual_node_properties,
            hyperedge_properties=self.node_properties,  # Properties of original nodes now apply to dual hyperedges
            name=name or (self.name + "_dual" if self.name else "dual")
        )

    def nodes(self) -> List[Any]:
        """
        Returns a list of all unique node identifiers in the hypergraph.

        Returns
        -------
        List[Any]
            List of node IDs.
        """
        return list(self.node_properties.keys())

    def node_memberships(self) -> Dict[Any, List[Any]]:
        """
        Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.

        Returns
        -------
        Dict[Any, List[Any]]
            Dictionary mapping node IDs to the hyperedge IDs they belong to.
        """
        memberships = {}
        for he_id, hyperedge in self.hyperedges.items():
            for node in hyperedge.nodes:
                memberships.setdefault(node, []).append(he_id)
        return memberships

    def edges(self) -> List[Any]:
        """
        Returns a list of all hyperedge identifiers in the hypergraph.

        Returns
        -------
        List[Any]
            List of hyperedge IDs.
        """
        return list(self.hyperedges.keys())

    def edge_elements(self) -> Dict[Any, List[Any]]:
        """
        Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.

        Returns
        -------
        Dict[Any, List[Any]]
            Dictionary mapping hyperedge IDs to lists of node IDs they contain.
        """
        return {he_id: hyperedge.nodes for he_id, hyperedge in self.hyperedges.items()}

    def __str__(self) -> str:
        """
        String representation of the hypergraph.

        Returns
        -------
        str
            A string describing the hypergraph with its name, number of nodes, and hyperedges.
        """
        return f"Hypergraph '{self.name}' with {len(self)} nodes and {len(self.hyperedges)} hyperedges."

    def __repr__(self) -> str:
        """
        Official string representation of the hypergraph.

        Returns
        -------
        str
            A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.
        """
        return (
            f"Hypergraph(name={self.name!r}, "
            f"nodes={len(self)}, hyperedges={len(self.hyperedges)})"
        )

    def __len__(self) -> int:
        """
        Returns the number of nodes in the hypergraph.

        Returns
        -------
        int
            Number of nodes.
        """
        return len(self.node_properties)

    def __iter__(self) -> Iterator[Any]:
        """
        Allows iteration over the nodes of the hypergraph.

        Yields
        ------
        Any
            Node identifiers.
        """
        return iter(self.node_properties)

    def __contains__(self, item: Any) -> bool:
        """
        Checks if a node is in the hypergraph.

        Parameters
        ----------
        item : Any
            The node identifier to check.

        Returns
        -------
        bool
            True if the node exists in the hypergraph, False otherwise.
        """
        return item in self.node_properties

    def __getitem__(self, node: Any) -> Iterable[Any]:
        """
        Retrieves the neighbors of a node in the hypergraph.

        Neighbors are nodes that share at least one hyperedge with the given node.

        Parameters
        ----------
        node : Any
            The node identifier.

        Returns
        -------
        Iterable[Any]
            An iterator over neighboring node identifiers.

        Raises
        ------
        HypergraphError
            If the node does not exist in the hypergraph.
        """
        if node not in self.node_properties:
            raise HypergraphError(f"Node '{node}' does not exist in the hypergraph.")

        # Get all hyperedges that include the node
        hyperedges = set(self.graph.neighbors(node))

        # Get all nodes connected by these hyperedges
        neighbors = set()
        for he_id in hyperedges:
            neighbors.update(self.hyperedges[he_id].nodes)

        neighbors.discard(node)  # Remove the node itself
        return neighbors
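
    # Illustrative sketch (not part of the source): the special methods above make
    # a Hypergraph behave like a collection of its nodes:
    #   len(hg)       # number of nodes
    #   'a' in hg     # membership test against node_properties
    #   list(hg)      # iterates over node IDs
    #   hg['a']       # neighbors of 'a': nodes sharing at least one hyperedge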

    def __eq__(self, other: Any) -> bool:
        """
        Checks if two hypergraphs are equal based on their hyperedges and nodes.

        Parameters
        ----------
        other : Any
            The other object to compare.

        Returns
        -------
        bool
            True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.
        """
        if not isinstance(other, Hypergraph):
            return False

        # Compare nodes and their properties
        if self.node_properties != other.node_properties:
            return False

        # Compare hyperedges and their properties
        if self.hyperedges.keys() != other.hyperedges.keys():
            return False

        for he_id in self.hyperedges:
            if self.hyperedges[he_id].nodes != other.hyperedges[he_id].nodes:
                return False
            if self.hyperedge_properties.get(he_id, {}) != other.hyperedge_properties.get(he_id, {}):
                return False

        return True

    def copy(self, name: Optional[str] = None) -> 'Hypergraph':
        """
        Creates a deep copy of the hypergraph instance.

        Parameters
        ----------
        name : Optional[str], default=None
            The name for the copied Hypergraph. If not provided, retains the original name.

        Returns
        -------
        Hypergraph
            A new Hypergraph instance that is a deep copy of the original.
        """

        # Deep copy hyperedges
        hyperedges_dict = {}
        for he_id, he in self.hyperedges.items():
            hyperedges_dict[he_id] = {
                'nodes': list(he.nodes),
                'properties': copy.deepcopy(he.properties)
            }

        # Deep copy node_properties and hyperedge_properties
        node_properties_copy = copy.deepcopy(self.node_properties)
        hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)

        # Create a new Hypergraph instance with the copied data
        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=node_properties_copy,
            hyperedge_properties=hyperedge_properties_copy,
            name=name if name is not None else self.name
        )

    def deepcopy(self, name: Optional[str] = None) -> 'Hypergraph':
        """
        Creates a deep copy of the hypergraph.

        Parameters
        ----------
        name : Optional[str], default=None
            The name assigned to the cloned hypergraph. If None, defaults to the original hypergraph's name suffixed with '_deepcopy'.

        Returns
        -------
        Hypergraph
            A deep copy of the hypergraph.
        """

        # Deep copy hyperedges
        hyperedges_copy = {
            he_id: {
                'nodes': hyperedge.nodes.copy(),
                'properties': copy.deepcopy(hyperedge.properties)
            }
            for he_id, hyperedge in self.hyperedges.items()
        }

        # Deep copy node properties
        node_properties_copy = copy.deepcopy(self.node_properties)

        # Deep copy hyperedge properties
        hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)

        # Set name
        cloned_name = f"{self.name}_deepcopy" if name is None else name

        # Initialize the cloned hypergraph
        cloned_H = Hypergraph(
            hyperedges=hyperedges_copy,
            node_properties=node_properties_copy,
            hyperedge_properties=hyperedge_properties_copy,
            name=cloned_name
        )

        return cloned_H
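
    # Illustrative note (not part of the source): copy() and deepcopy() both make
    # fully independent deep copies; they differ only in the default name, since
    # copy() keeps the original name while deepcopy() appends '_deepcopy'.
    #   clone = hg.deepcopy()
    #   clone == hg   # True: __eq__ compares structure, not names
    #   clone is hg   # False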

    # Adding and Removing Hyperedges and Nodes

    def add_hyperedge(
        self,
        he_id: Any,
        nodes: Iterable[Any],
        properties: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Adds a hyperedge to the hypergraph.

        Parameters
        ----------
        he_id : Any
            Unique identifier for the hyperedge.
        nodes : Iterable[Any]
            Nodes connected by the hyperedge.
        properties : Optional[Dict[str, Any]] = None
            Properties of the hyperedge.

        Raises
        ------
        HypergraphError
            If the hyperedge ID already exists.
        """
        if he_id in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' already exists.")

        hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)
        self.hyperedges[he_id] = hyperedge
        self.hyperedge_properties[he_id] = copy.deepcopy(properties) if properties else {}

        # Add hyperedge to bipartite graph
        self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties[he_id])
        self.bipartite_nodes.add(he_id)

        # Add edges between hyperedge and nodes, registering any nodes not yet
        # known so that node_properties (the source of truth for nodes(), __len__,
        # and __contains__) stays consistent with the hyperedge membership
        for node in hyperedge.nodes:
            if node not in self.node_properties:
                self.node_properties[node] = {}
            if node not in self.graph:
                self.graph.add_node(node, bipartite='node', **self.node_properties[node])
            self.graph.add_edge(he_id, node)

    def remove_hyperedge(self, he_id: Any) -> None:
        """
        Removes a hyperedge from the hypergraph.

        Parameters
        ----------
        he_id : Any
            Identifier of the hyperedge to remove.

        Raises
        ------
        HypergraphError
            If the hyperedge does not exist.
        """
        if he_id not in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' does not exist.")

        # Remove hyperedge from the graph, which also removes all incidences
        self.graph.remove_node(he_id)
        self.bipartite_nodes.discard(he_id)

        # Remove from internal structures
        del self.hyperedges[he_id]
        self.hyperedge_properties.pop(he_id, None)
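
    # Illustrative sketch (not part of the source): add_hyperedge/remove_hyperedge
    # keep the HyperEdge registry, the property dicts, and the underlying bipartite
    # graph in sync:
    #   hg.add_hyperedge('e3', nodes=['a', 'c'], properties={'weight': 2})
    #   hg.remove_hyperedge('e3')   # also drops all of e3's incidences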

    def add_hyperedges_from(
        self,
        hyperedges: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds multiple hyperedges with attributes to the hypergraph.

        Parameters
        ----------
        hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]
            An iterable of hyperedge identifiers or tuples of (he_id, attributes).
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any hyperedge ID already exists.
        ValueError
            If any tuple does not contain exactly two elements or if attributes are not dictionaries.
        """
        new_hyperedges = []
        for item in hyperedges:
            if isinstance(item, tuple):
                if len(item) != 2 or not isinstance(item[1], dict):
                    raise ValueError(f"Each tuple must be of the form (he_id, attributes). Invalid tuple: {item}")
                he_id, attrs = item
            else:
                he_id, attrs = item, {}

            if he_id in self.hyperedges:
                raise HypergraphError(f"Hyperedge '{he_id}' already exists.")

            hyperedge = HyperEdge(id=he_id, nodes=[], properties=attrs.copy())
            new_hyperedges.append(hyperedge)

        if inplace:
            for hyperedge in new_hyperedges:
                self.hyperedges[hyperedge.id] = hyperedge
                self.hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)
                self.graph.add_node(hyperedge.id, bipartite='hyperedge', **self.hyperedge_properties[hyperedge.id])
                self.bipartite_nodes.add(hyperedge.id)
            return self
        else:
            # Create a new Hypergraph instance with added hyperedges
            new_hyperedges_dict = copy.deepcopy(self.hyperedges)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for hyperedge in new_hyperedges:
                new_hyperedges_dict[hyperedge.id] = hyperedge
                new_hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)
                new_graph.add_node(hyperedge.id, bipartite='hyperedge', **new_hyperedge_properties[hyperedge.id])
                new_bipartite_nodes.add(hyperedge.id)

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges_dict.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def add_node(
        self,
        node_id: Any,
        properties: Optional[Dict[str, Any]] = None,
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds a node to the hypergraph.

        Parameters
        ----------
        node_id : Any
            Identifier for the node.
        properties : Optional[Dict[str, Any]] = None
            Properties of the node.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If the node ID already exists.
        """
        if node_id in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' already exists in the hypergraph.")

        if inplace:
            self.node_properties[node_id] = copy.deepcopy(properties) if properties else {}
            self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])
            return self
        else:
            # Create a new Hypergraph instance with the added node
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            new_node_properties[node_id] = copy.deepcopy(properties) if properties else {}
            new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])

            return Hypergraph(
                hyperedges={
                    he_id: {
                        'nodes': list(he.nodes),
                        'properties': he.properties.copy()
                    } for he_id, he in new_hyperedges.items()
                },
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_node(self, node_id: Any, inplace: bool = True) -> 'Hypergraph':
        """
        Removes a node from the hypergraph.

        Parameters
        ----------
        node_id : Any
            Identifier of the node to remove.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the node removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If the node does not exist.
        """
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")

        if inplace:
            # Remove node from node_properties
            del self.node_properties[node_id]
            # Remove node from all hyperedges
            for hyperedge in self.hyperedges.values():
                if node_id in hyperedge.nodes:
                    hyperedge.remove_node(node_id)
            # Remove node from graph, which also removes all incidences
            self.graph.remove_node(node_id)
            return self
        else:
            # Create a new Hypergraph instance with the node removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            # Remove node from node_properties
            del new_node_properties[node_id]
            # Remove node from all hyperedges
            for hyperedge in new_hyperedges.values():
                if node_id in hyperedge.nodes:
                    hyperedge.remove_node(node_id)
            # Remove node from graph, which also removes all incidences
            new_graph.remove_node(node_id)

            # Remove nodes not connected to any hyperedges
            retained_nodes = set()
            for hyperedge in new_hyperedges.values():
                retained_nodes.update(hyperedge.nodes)

            new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def add_nodes_from(
        self,
        nodes: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds multiple nodes with attributes to the hypergraph.

        Parameters
        ----------
        nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]
            An iterable of node identifiers or tuples of (node_id, attributes).
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any node ID already exists.
        ValueError
            If any tuple does not contain exactly two elements or if attributes are not dictionaries.
        """
        new_nodes = {}
        for item in nodes:
            if isinstance(item, tuple):
                if len(item) != 2 or not isinstance(item[1], dict):
                    raise ValueError(f"Each tuple must be of the form (node_id, attributes). Invalid tuple: {item}")
                node_id, attrs = item
            else:
                node_id, attrs = item, {}

            if node_id in self.node_properties:
                raise HypergraphError(f"Node '{node_id}' already exists in the hypergraph.")

            new_nodes[node_id] = copy.deepcopy(attrs)

        if inplace:
            for node_id, attrs in new_nodes.items():
                self.node_properties[node_id] = attrs
                self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])
            return self
        else:
            # Create a new Hypergraph instance with the added nodes
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for node_id, attrs in new_nodes.items():
                new_node_properties[node_id] = attrs
                new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])

            return Hypergraph(
                hyperedges={
                    he_id: {
                        'nodes': list(he.nodes),
                        'properties': he.properties.copy()
                    } for he_id, he in new_hyperedges.items()
                },
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )
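
    # Illustrative sketch (not part of the source): the bulk add_* methods accept
    # bare IDs or (id, attributes) tuples and validate every item before mutating,
    # so a bad item leaves the hypergraph unchanged:
    #   hg.add_nodes_from(['x', ('y', {'color': 'red'})])
    #   hg2 = hg.add_nodes_from(['z'], inplace=False)   # hg itself is untouched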

    def remove_hyperedges(self, he_ids: Union[Any, Iterable[Any]], inplace: bool = True) -> 'Hypergraph':
        """
        Removes the specified hyperedges from the hypergraph.

        Parameters
        ----------
        he_ids : Any | Iterable[Any]
            Hyperedge identifier(s) to remove.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any hyperedge ID does not exist.
        """
        if isinstance(he_ids, (str, int)):
            he_ids = [he_ids]
        else:
            he_ids = list(he_ids)

        non_existing = set(he_ids) - set(self.hyperedges.keys())
        if non_existing:
            raise HypergraphError(f"Hyperedges {non_existing} do not exist in the hypergraph.")

        if inplace:
            for he_id in he_ids:
                self.remove_hyperedge(he_id)
            return self
        else:
            # Create a new Hypergraph instance with hyperedges removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for he_id in he_ids:
                del new_hyperedges[he_id]
                new_hyperedge_properties.pop(he_id, None)
                new_graph.remove_node(he_id)
                new_bipartite_nodes.discard(he_id)

            # Remove nodes not connected to any hyperedges
            retained_nodes = set()
            for hyperedge in new_hyperedges.values():
                retained_nodes.update(hyperedge.nodes)

            new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_nodes_from(
        self,
        nodes: Union[Any, Iterable[Any]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Removes the specified nodes from the hypergraph.

        Parameters
        ----------
        nodes : Any | Iterable[Any]
            Node identifier(s) to remove.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any node ID does not exist.
        """
        if isinstance(nodes, (str, int)):
            nodes = [nodes]
        else:
            nodes = list(nodes)

        non_existing = set(nodes) - set(self.node_properties.keys())
        if non_existing:
            raise HypergraphError(f"Nodes {non_existing} do not exist in the hypergraph.")

        if inplace:
            for node_id in nodes:
                self.remove_node(node_id)
            return self
        else:
            # Create a new Hypergraph instance with nodes removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for node_id in nodes:
                del new_node_properties[node_id]
                # Remove node from all hyperedges
                for hyperedge in new_hyperedges.values():
                    if node_id in hyperedge.nodes:
                        hyperedge.remove_node(node_id)
                # Remove node from graph, which also removes all incidences
                new_graph.remove_node(node_id)

            # Remove nodes not connected to any hyperedges
            retained_nodes = set()
            for hyperedge in new_hyperedges.values():
                retained_nodes.update(hyperedge.nodes)

            new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def add_incidence(
        self,
        he_id: Any,
        node_id: Any,
        attributes: Optional[Dict[str, Any]] = None,
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds a single incidence with attributes to the hypergraph.

        Parameters
        ----------
        he_id : Any
            Identifier of the hyperedge.
        node_id : Any
            Identifier of the node.
        attributes : Optional[Dict[str, Any]] = None
            Properties to add to the incidence as key-value pairs.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If the hyperedge or node does not exist, or if the incidence already exists.
        """
        if he_id not in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
        if node_id in self.hyperedges[he_id].nodes:
            raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.")

        if inplace:
            # Add node to HyperEdge's nodes
            self.hyperedges[he_id].add_node(node_id)
            # Update hyperedge_properties if attributes provided
            if attributes:
                self.hyperedge_properties[he_id].update(attributes)
            # Add edge in graph with attributes
            self.graph.add_edge(he_id, node_id, **(attributes if attributes else {}))
            return self
        else:
            # Create a new Hypergraph instance with the incidence added
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            # Add node to HyperEdge's nodes
            new_hyperedges[he_id].add_node(node_id)
            # Update hyperedge_properties if attributes provided
            if attributes:
                new_hyperedge_properties[he_id].update(attributes)
            # Add edge in graph with attributes
            new_graph.add_edge(he_id, node_id, **(attributes if attributes else {}))

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_incidence(
        self,
        he_id: Any,
        node_id: Any,
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Removes a single incidence from the hypergraph.

        Parameters
        ----------
        he_id : Any
            Identifier of the hyperedge.
        node_id : Any
            Identifier of the node.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidence removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If the hyperedge or node does not exist, or if the incidence does not exist.
        """
        if he_id not in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
        if node_id not in self.hyperedges[he_id].nodes:
            raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.")

        if inplace:
            # Remove node from HyperEdge's nodes
            self.hyperedges[he_id].remove_node(node_id)
            # Remove edge from graph
            self.graph.remove_edge(he_id, node_id)
            return self
        else:
            # Create a new Hypergraph instance with the incidence removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            # Remove node from HyperEdge's nodes
            new_hyperedges[he_id].remove_node(node_id)
            # Remove edge from graph
            new_graph.remove_edge(he_id, node_id)

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )
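
    # Illustrative sketch (not part of the source): an incidence links an existing
    # hyperedge to an existing node, so both endpoints must be present first:
    #   hg.add_node('c')
    #   hg.add_incidence('e1', 'c')
    #   hg.remove_incidence('e1', 'c')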

    # Managing Properties and Incidences

    def adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[Any, int]]:
        """
        Generates the adjacency matrix for nodes based on s-node connectivity.

        Parameters
        ----------
        s : int, optional, default=1
            The number of shared hyperedges required for two nodes to be considered adjacent.
        index : bool, optional, default=False
            If True, returns a mapping from node IDs to matrix indices.

        Returns
        -------
        Tuple[Optional[csr_matrix], Dict[Any, int]]
            - The adjacency matrix in CSR format, or None if the hypergraph has no nodes.
            - A dictionary mapping node IDs to matrix indices (empty if `index` is False).
        """
        from scipy.sparse import lil_matrix

        node_ids = list(self.node_properties.keys())
        node_index = {node_id: idx for idx, node_id in enumerate(node_ids)}
        size = len(node_ids)
        if size == 0:
            return None, {}

        A = lil_matrix((size, size), dtype=int)
        for he in self.hyperedges.values():
            nodes = list(he.nodes)
            for i in range(len(nodes)):
                for j in range(i + 1, len(nodes)):
                    A[node_index[nodes[i]], node_index[nodes[j]]] += 1

        # Apply the threshold s and convert to binary
        A = (A >= s).astype(int)
        A = A.tocsr()

        if index:
            return A, node_index
        return A, {}
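
    # Illustrative sketch (not part of the source): with s=2, two nodes are
    # adjacent only if they co-occur in at least two hyperedges:
    #   A, idx = hg.adjacency_matrix(s=2, index=True)
    #   A[idx['a'], idx['b']]   # 1 if 'a' and 'b' share >= 2 hyperedges, else 0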

    def hyperedge_adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[Any, int]]:
        """
        Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.

        Parameters
        ----------
        s : int, optional, default=1
            The number of shared nodes required for two hyperedges to be considered adjacent.
        index : bool, optional, default=False
            If True, returns a mapping from hyperedge IDs to matrix indices.

        Returns
        -------
        Tuple[Optional[csr_matrix], Dict[Any, int]]
            - The adjacency matrix in CSR format, or None if the hypergraph has no hyperedges.
            - A dictionary mapping hyperedge IDs to matrix indices (empty if `index` is False).
        """
        from scipy.sparse import lil_matrix

        hyperedge_ids = list(self.hyperedges.keys())
        he_index = {he_id: idx for idx, he_id in enumerate(hyperedge_ids)}
        size = len(hyperedge_ids)
        if size == 0:
            return None, {}

        A = lil_matrix((size, size), dtype=int)
        for i, he1 in enumerate(hyperedge_ids):
            nodes1 = self.hyperedges[he1].nodes
            for j in range(i + 1, size):
                he2 = hyperedge_ids[j]
                nodes2 = self.hyperedges[he2].nodes
                shared_nodes = nodes1 & nodes2
                if len(shared_nodes) >= s:
                    A[i, j] = 1
                    A[j, i] = 1

        A = A.tocsr()

        if index:
            return A, he_index
        return A, {}
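
    # Illustrative sketch (not part of the source): the result is the adjacency
    # matrix of the s-line graph, whose vertices are hyperedges:
    #   A, idx = hg.hyperedge_adjacency_matrix(s=1, index=True)
    #   A[idx['e1'], idx['e2']]   # 1 if e1 and e2 share at least one node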

    def get_hyperedges_of_node(self, node_id: Any) -> Set[Any]:
        """
        Retrieves all hyperedges that a given node is part of.

        Parameters
        ----------
        node_id : Any
            The node identifier.

        Returns
        -------
        Set[Any]
            A set of hyperedge IDs that the node belongs to.

        Raises
        ------
        HypergraphError
            If the node does not exist in the hypergraph.
        """
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
        return {he.id for he in self.hyperedges.values() if node_id in he.nodes}

    def collapse_duplicate_hyperedges(
        self,
        name: Optional[str] = None,
        use_uids: Optional[List[Any]] = None,
        use_counts: bool = False,
        return_counts: bool = True,
        return_equivalence_classes: bool = False,
        aggregate_properties_by: Optional[Dict[str, str]] = None,
    ) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:
        """
        Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.

        Parameters
        ----------
        name : Optional[str], default=None
            The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.

        use_uids : Optional[List[Any]] = None
            Specifies the hyperedge identifiers to use as representatives for each equivalence class.
            If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.
            If None, the first encountered hyperedge in each class is used as the representative.

        use_counts : bool, optional, default=False
            If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').

        return_counts : bool, optional, default=True
            If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.

        return_equivalence_classes : bool, optional, default=False
            If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.

        aggregate_properties_by : Optional[Dict[str, str]] = None
            A dictionary specifying aggregation methods for hyperedge properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).
            Properties not specified will use the 'first' aggregation.

        Returns
        -------
        Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]
            - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.
            - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.

        Raises
        ------
        HypergraphError
            If the hypergraph is empty or improperly structured.
        """
        if not self.hyperedges:
            raise HypergraphError("Cannot collapse hyperedges in an empty hypergraph.")

        # Identify equivalence classes based on identical node memberships
        membership_to_hyperedges: Dict[frozenset, Set[Any]] = {}
        for he_id, hyperedge in self.hyperedges.items():
            key = frozenset(hyperedge.nodes)
            membership_to_hyperedges.setdefault(key, set()).add(he_id)

        # Filter out classes with only one hyperedge (no duplicates)
        equivalence_classes = [hes for hes in membership_to_hyperedges.values() if len(hes) > 1]
        if not equivalence_classes:
            # No duplicates to collapse; return the original hypergraph
            return self if not return_equivalence_classes else (self, {})

        # Prepare aggregation methods
        aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {"weight": "sum"}

        # Initialize mapping from old hyperedges to new hyperedges
        hyperedge_mapping: Dict[Any, Any] = {}
        equivalence_class_dict: Dict[Any, Set[Any]] = {}

        for eq_class in equivalence_classes:
            # Determine representative
            if use_uids:
                # Select the first UID from use_uids that is in the equivalence class
                representative = next((uid for uid in use_uids if uid in eq_class), None)
                if not representative:
                    # Fallback to the first hyperedge in the equivalence class
                    representative = next(iter(eq_class))
            else:
                # Use the first hyperedge in the equivalence class as representative
                representative = next(iter(eq_class))

            # Optionally rename with counts
            if use_counts:
                new_representative = f"{representative}:{len(eq_class)}"
            else:
                new_representative = representative

            # Map all hyperedges in the class to the representative
            for he in eq_class:
                hyperedge_mapping[he] = new_representative

            # Store the equivalence class
            equivalence_class_dict[new_representative] = eq_class

        # Replace hyperedge IDs in incidences based on mapping
        new_hyperedges = {}
        for he_id, hyperedge in self.hyperedges.items():
            new_he_id = hyperedge_mapping.get(he_id, he_id)
            if new_he_id not in new_hyperedges:
                new_hyperedges[new_he_id] = HyperEdge(id=new_he_id, nodes=hyperedge.nodes.copy(), properties=copy.deepcopy(hyperedge.properties))
            else:
                new_hyperedges[new_he_id].nodes.update(hyperedge.nodes)

        # Aggregate hyperedge properties
        for he_id, hyperedge in new_hyperedges.items():
            if he_id in equivalence_class_dict:
                aggregated_props = {}
                for prop, agg_func in aggregate_properties_by.items():
                    values = [self.hyperedge_properties[old_he].get(prop, 0) for old_he in equivalence_class_dict[he_id]]
                    if agg_func == 'sum':
                        aggregated_props[prop] = sum(values)
                    elif agg_func == 'mean':
                        aggregated_props[prop] = sum(values) / len(values) if values else 0
                    elif agg_func == 'max':
                        aggregated_props[prop] = max(values) if values else None
                    elif agg_func == 'min':
                        aggregated_props[prop] = min(values) if values else None
                    else:
                        aggregated_props[prop] = values[0] if values else None  # Default to first
                new_hyperedges[he_id].properties.update(aggregated_props)

        # Handle equivalence class size
        if use_counts:
            for he_id in equivalence_class_dict:
                new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])
        elif return_counts:
            for he_id in new_hyperedges:
                if he_id in equivalence_class_dict:
                    new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])
                else:
                    new_hyperedges[he_id].properties['equivalence_class_size'] = 1

        # Initialize the collapsed hypergraph
        collapsed_hypergraph = Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            },
            node_properties=copy.deepcopy(self.node_properties),
            hyperedge_properties={
                he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()
            },
            name=name if name else f"{self.name}_collapsed_hyperedges"
        )

        if return_equivalence_classes:
            return collapsed_hypergraph, equivalence_class_dict
        else:
            return collapsed_hypergraph
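
    # Illustrative sketch (not part of the source): hyperedges over identical node
    # sets collapse to a single representative, with 'weight' summed by default:
    #   hg = Hypergraph(hyperedges={
    #       'e1': {'nodes': ['a', 'b'], 'properties': {'weight': 1}},
    #       'e2': {'nodes': ['a', 'b'], 'properties': {'weight': 3}},
    #   })
    #   collapsed = hg.collapse_duplicate_hyperedges()
    #   len(collapsed.edges())   # 1; the surviving hyperedge has weight 4 and
    #                            # equivalence_class_size 2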

    def restrict_to_specific_hyperedges(
        self,
        hyperedges_to_retain: Iterable[Any],
        name: Optional[str] = None
    ) -> 'Hypergraph':
        """
        Creates a new hypergraph by retaining only the specified hyperedges and removing all others.

        Parameters
        ----------
        hyperedges_to_retain : Iterable[Any]
            An iterable of hyperedge identifiers to retain in the new hypergraph.

        name : Optional[str], default=None
            The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_hyperedges'.

        Returns
        -------
        Hypergraph
            A new hypergraph containing only the specified hyperedges and their associated nodes.

        Raises
        ------
        HypergraphError
            If none of the specified hyperedges exist in the hypergraph.
        """
        hyperedges_to_retain = set(hyperedges_to_retain)
        existing_hyperedges = set(self.hyperedges.keys())
        invalid_hyperedges = hyperedges_to_retain - existing_hyperedges
        if invalid_hyperedges:
            raise HypergraphError(f"The following hyperedges do not exist and cannot be retained: {invalid_hyperedges}")

        # Determine hyperedges to remove
        hyperedges_to_remove = existing_hyperedges - hyperedges_to_retain
        if not hyperedges_to_remove:
            # No hyperedges to remove; return the original hypergraph
            return self

        # Remove hyperedges using the existing remove_hyperedges method
        restricted_hypergraph = self.remove_hyperedges(hyperedges_to_remove, inplace=False)
        restricted_hypergraph.name = name if name else f"{self.name}_restricted_hyperedges"

        return restricted_hypergraph

    def restrict_to_specific_nodes(
        self,
        nodes_to_retain: Iterable[Any],
        name: Optional[str] = None
    ) -> 'Hypergraph':
        """
        Creates a new hypergraph by retaining only the specified nodes and removing all others.

        Parameters
        ----------
        nodes_to_retain : Iterable[Any]
            An iterable of node identifiers to retain in the new hypergraph.

        name : Optional[str], default=None
            The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_nodes'.

        Returns
        -------
        Hypergraph
            A new hypergraph containing only the specified nodes and their associated hyperedges.

        Raises
        ------
        HypergraphError
            If none of the specified nodes exist in the hypergraph.
        """
        nodes_to_retain = set(nodes_to_retain)
        existing_nodes = set(self.node_properties.keys())
        invalid_nodes = nodes_to_retain - existing_nodes
        if invalid_nodes:
            raise HypergraphError(f"The following nodes do not exist and cannot be retained: {invalid_nodes}")

        # Determine nodes to remove
        nodes_to_remove = existing_nodes - nodes_to_retain
        if not nodes_to_remove:
            # No nodes to remove; return the original hypergraph
            return self

        # Remove nodes using the existing remove_nodes_from method
        restricted_hypergraph = self.remove_nodes_from(nodes_to_remove, inplace=False)
        restricted_hypergraph.name = name if name else f"{self.name}_restricted_nodes"

        return restricted_hypergraph
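
    # Illustrative sketch (not part of the source): restriction is the complement
    # of removal, so nodes left without any hyperedge are pruned along the way:
    #   sub_nodes = hg.restrict_to_specific_nodes(['a', 'b'], name='ab_subgraph')
    #   sub_edges = hg.restrict_to_specific_hyperedges(['e1'])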

    def add_incidences_from(
        self,
        incidences: Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds a collection of incidences to the hypergraph.

        Parameters
        ----------
        incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]]
            Incidence tuples as:
                - (he_id, node_id)
                - (he_id, node_id, attributes)

        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidences.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any hyperedge or node does not exist, or if any incidence already exists.
        ValueError
            If the structure of any incidence tuple is invalid.
        """
        new_incidences = []
        for pr in incidences:
            if not isinstance(pr, tuple):
                raise ValueError(f"Each incidence must be a tuple, got {type(pr)}")
            if len(pr) == 2:
                he_id, node_id = pr
                attrs = {}
            elif len(pr) == 3:
                he_id, node_id, attrs = pr
                if not isinstance(attrs, dict):
                    raise ValueError(f"Attributes must be a dictionary, got {type(attrs)}")
            else:
                raise ValueError(f"Incidence tuples must be of length 2 or 3, got {len(pr)}")

            if he_id not in self.hyperedges:
                raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
            if node_id not in self.node_properties:
                raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
            if node_id in self.hyperedges[he_id].nodes:
                raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.")

            new_incidences.append((he_id, node_id, attrs.copy()))

        if inplace:
            for he_id, node_id, attrs in new_incidences:
                # Add node to HyperEdge's nodes
                self.hyperedges[he_id].add_node(node_id)
                # Update hyperedge_properties if attributes provided
                if attrs:
                    self.hyperedge_properties[he_id].update(attrs)
                # Add edge in graph with attributes
                self.graph.add_edge(he_id, node_id, **(attrs if attrs else {}))
            return self
        else:
            # Create a new Hypergraph instance with the incidences added
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for he_id, node_id, attrs in new_incidences:
                # Add node to HyperEdge's nodes
                new_hyperedges[he_id].add_node(node_id)
                # Update hyperedge_properties if attributes provided
                if attrs:
                    new_hyperedge_properties[he_id].update(attrs)
                # Add edge in graph with attributes
                new_graph.add_edge(he_id, node_id, **(attrs if attrs else {}))

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_incidences(
        self,
        incidences: Iterable[Tuple[Any, Any]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Removes the specified incidences from the hypergraph.

        Parameters
        ----------
        incidences : Iterable[Tuple[Any, Any]]
            Incidence identifiers as tuples of (he_id, node_id).
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidences removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any incidence does not exist.
        """
        incidence_ids = list(incidences)

        # Check existence of incidences
        for he_id, node_id in incidence_ids:
            if he_id not in self.hyperedges:
                raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
            if node_id not in self.node_properties:
                raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
            if node_id not in self.hyperedges[he_id].nodes:
                raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.")

        if inplace:
            for he_id, node_id in incidence_ids:
                # Remove node from HyperEdge's nodes
                self.hyperedges[he_id].remove_node(node_id)
                # Remove edge from graph
                self.graph.remove_edge(he_id, node_id)
            return self
        else:
            # Create a new Hypergraph instance with the incidences removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for he_id, node_id in incidence_ids:
                # Remove node from HyperEdge's nodes
                new_hyperedges[he_id].remove_node(node_id)
                # Remove edge from graph
                new_graph.remove_edge(he_id, node_id)

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def collapse_duplicate_nodes(
        self,
        name: Optional[str] = None,
        use_uids: Optional[List[Any]] = None,
        use_counts: bool = False,
        return_counts: bool = True,
        return_equivalence_classes: bool = False,
        aggregate_properties_by: Optional[Dict[str, str]] = None,
    ) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:
        """
        Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.

        Parameters
        ----------
        name : Optional[str], default=None
            The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.

        use_uids : Optional[List[Any]] = None
            Specifies the node identifiers to use as representatives for each equivalence class.
            If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.
            If None, the first encountered node in each class is used as the representative.

        use_counts : bool, optional, default=False
            If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').

        return_counts : bool, optional, default=True
            If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.

        return_equivalence_classes : bool, optional, default=False
            If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.

        aggregate_properties_by : Optional[Dict[str, str]] = None
            A dictionary specifying aggregation methods for node properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).
            Properties not specified will use the 'first' aggregation.

        Returns
        -------
        Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]
            - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.
            - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.

        Raises
        ------
        HypergraphError
            If the hypergraph is empty or improperly structured.
        """
        if not self.node_properties:
            raise HypergraphError("Cannot collapse nodes in an empty hypergraph.")

        # Identify equivalence classes based on identical hyperedge memberships
        membership_to_nodes: Dict[frozenset, Set[Any]] = {}
        for node_id, node_props in self.node_properties.items():
            key = frozenset(self.get_hyperedges_of_node(node_id))
            membership_to_nodes.setdefault(key, set()).add(node_id)

        # Filter out classes with only one node (no duplicates)
        equivalence_classes = [nodes for nodes in membership_to_nodes.values() if len(nodes) > 1]
        if not equivalence_classes:
            # No duplicates to collapse; return the original hypergraph
            return self if not return_equivalence_classes else (self, {})

        # Prepare aggregation methods
        aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {"weight": "sum"}

        # Initialize mapping from old nodes to new nodes
        node_mapping: Dict[Any, Any] = {}
        equivalence_class_dict: Dict[Any, Set[Any]] = {}

        for eq_class in equivalence_classes:
            # Determine representative
            if use_uids:
                # Select the first UID from use_uids that is in the equivalence class
                representative = next((uid for uid in use_uids if uid in eq_class), None)
                if not representative:
                    # Fallback to the first node in the equivalence class
                    representative = next(iter(eq_class))
            else:
                # Use the first node in the equivalence class as representative
                representative = next(iter(eq_class))

            # Optionally rename with counts
            if use_counts:
                new_representative = f"{representative}:{len(eq_class)}"
            else:
                new_representative = representative

            # Map all nodes in the class to the representative
            for node in eq_class:
                node_mapping[node] = new_representative

            # Store the equivalence class
            equivalence_class_dict[new_representative] = eq_class

        # Replace node IDs in hyperedges based on mapping
        new_hyperedges = {}
        for he_id, hyperedge in self.hyperedges.items():
            new_nodes = set()
            for node_id in hyperedge.nodes:
                new_node_id = node_mapping.get(node_id, node_id)
                new_nodes.add(new_node_id)
            new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties=copy.deepcopy(hyperedge.properties))

        # Aggregate node properties
        new_node_properties = {}
        for node_id, node_props in self.node_properties.items():
            new_node_id = node_mapping.get(node_id, node_id)
            if new_node_id not in new_node_properties:
                new_node_properties[new_node_id] = copy.deepcopy(node_props)
            else:
                for prop, agg_func in aggregate_properties_by.items():
                    if prop in node_props:
                        if agg_func == 'sum':
                            new_node_properties[new_node_id][prop] = new_node_properties[new_node_id].get(prop, 0) + node_props[prop]
                        elif agg_func == 'mean':
                            # To calculate mean, accumulate a running sum and count.
                            # Seed the sum with the representative's original value so
                            # the first member of the class is not dropped from the mean.
                            if 'sum_' + prop not in new_node_properties[new_node_id]:
                                first_val = new_node_properties[new_node_id].get(prop, 0)
                                new_node_properties[new_node_id]['sum_' + prop] = first_val + node_props[prop]
                                new_node_properties[new_node_id]['count_' + prop] = 2
                            else:
                                new_node_properties[new_node_id]['sum_' + prop] += node_props[prop]
                                new_node_properties[new_node_id]['count_' + prop] += 1
                            # The mean itself is computed in the finalization pass below
                        elif agg_func == 'max':
                            current_max = new_node_properties[new_node_id].get(prop, float('-inf'))
                            new_node_properties[new_node_id][prop] = max(current_max, node_props[prop])
                        elif agg_func == 'min':
                            current_min = new_node_properties[new_node_id].get(prop, float('inf'))
                            new_node_properties[new_node_id][prop] = min(current_min, node_props[prop])
                        else:
                            # Default to 'first': keep the representative's existing value
                            new_node_properties[new_node_id].setdefault(prop, node_props[prop])
        # Finalize mean calculations
        for node_id, props in new_node_properties.items():
            for prop in list(props.keys()):
                if prop.startswith('sum_'):
                    base_prop = prop[4:]
                    sum_val = props[prop]
                    count_val = props.get('count_' + base_prop, 1)
                    new_node_properties[node_id][base_prop] = sum_val / count_val if count_val > 0 else 0
                    del new_node_properties[node_id][prop]
                    del new_node_properties[node_id]['count_' + base_prop]

        # Handle equivalence class size
        if use_counts:
            for node_id in equivalence_class_dict:
                new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])
        elif return_counts:
            for node_id in new_node_properties:
                if node_id in equivalence_class_dict:
                    new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])
                else:
                    new_node_properties[node_id]['equivalence_class_size'] = 1

        # Initialize the collapsed hypergraph
        collapsed_hypergraph = Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            },
            node_properties=new_node_properties,
            hyperedge_properties={
                he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()
            },
            name=name if name else f"{self.name}_collapsed_nodes"
        )

        if return_equivalence_classes:
            return collapsed_hypergraph, equivalence_class_dict
        else:
            return collapsed_hypergraph
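
    # Example (sketch): nodes that belong to exactly the same hyperedges collapse
    # into a single representative; with use_counts=True the representative is
    # renamed "<node>:<class size>" and gains an 'equivalence_class_size' property.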

    # Analyzing and Querying the Hypergraph

    def get_toplexes(self, return_hypergraph: bool = False) -> Union[List[Any], 'Hypergraph']:
        """
        Computes a maximal collection of toplexes for the hypergraph.
        A :term:`toplex` is a hyperedge that is not contained in any other hyperedge.

        Parameters
        ----------
        return_hypergraph : bool, optional, default=False
            If True, returns a new Hypergraph consisting only of the toplexes.

        Returns
        -------
        List[Any] or Hypergraph
            - A list of toplex hyperedge IDs.
            - If `return_hypergraph=True`, returns a Hypergraph containing only the toplexes.
        """
        toplexes = []
        hyperedges = list(self.hyperedges.values())

        for he in hyperedges:
            if not any(he.nodes < other_he.nodes for other_he in hyperedges if he.id != other_he.id):
                toplexes.append(he.id)

        if return_hypergraph:
            return self.restrict_to_specific_hyperedges(toplexes, name="Toplexes")
        return toplexes
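
    # Example (sketch): if HE1 = {N1, N2} and HE2 = {N1, N2, N3}, HE1 is properly
    # contained in HE2, so get_toplexes() returns only ['HE2'].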

    def is_node_connected(self, s: int = 1) -> bool:
        """
        Determines if the hypergraph is s-node-connected.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.

        Returns
        -------
        bool
            True if the hypergraph is s-node-connected, False otherwise.
        """
        return self._is_connected(s=s, hyperedges=False)

    def is_hyperedge_connected(self, s: int = 1) -> bool:
        """
        Determines if the hypergraph is s-hyperedge-connected.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.

        Returns
        -------
        bool
            True if the hypergraph is s-hyperedge-connected, False otherwise.
        """
        return self._is_connected(s=s, hyperedges=True)

    def _is_connected(self, s: int = 1, hyperedges: bool = False) -> bool:
        """
        Internal method to determine connectivity based on nodes or hyperedges.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.
        hyperedges : bool, optional, default=False
            If True, checks for s-hyperedge-connectedness. Otherwise, checks for s-node-connectedness.

        Returns
        -------
        bool
            Connectivity status.
        """
        if hyperedges:
            # Create hyperedge connectivity graph: hyperedges are nodes, connect if they share >= s nodes
            hyperedge_graph = nx.Graph()
            hyperedge_ids = list(self.hyperedges.keys())
            hyperedge_graph.add_nodes_from(hyperedge_ids)

            for i, he1 in enumerate(hyperedge_ids):
                nodes1 = self.hyperedges[he1].nodes
                for he2 in hyperedge_ids[i+1:]:
                    nodes2 = self.hyperedges[he2].nodes
                    shared_nodes = nodes1 & nodes2
                    if len(shared_nodes) >= s:
                        hyperedge_graph.add_edge(he1, he2)

            try:
                return nx.is_connected(hyperedge_graph)
            except nx.NetworkXPointlessConcept:
                return False
        else:
            # Create node connectivity graph: nodes are nodes, connect if they share >= s hyperedges
            node_graph = nx.Graph()
            node_ids = list(self.node_properties.keys())
            node_graph.add_nodes_from(node_ids)

            for i, node1 in enumerate(node_ids):
                hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}
                for node2 in node_ids[i+1:]:
                    hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}
                    shared_hyperedges = hyperedges1 & hyperedges2
                    if len(shared_hyperedges) >= s:
                        node_graph.add_edge(node1, node2)

            try:
                return nx.is_connected(node_graph)
            except nx.NetworkXPointlessConcept:
                return False

    def get_node_connected_components(
        self, s: int = 1, return_singletons: bool = False
    ) -> Iterator[Set[Any]]:
        """
        Yields the s-node-connected components of the hypergraph.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.
        return_singletons : bool, optional, default=False
            If True, includes singleton components. Otherwise, excludes them.

        Yields
        ------
        Set[Any]
            Sets of node IDs representing each connected component.
        """
        return self.s_connected_components(s=s, hyperedges=False, return_singletons=return_singletons)

    def get_hyperedge_connected_components(
        self, s: int = 1, return_singletons: bool = False
    ) -> Iterator[Set[Any]]:
        """
        Yields the s-hyperedge-connected components of the hypergraph.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.
        return_singletons : bool, optional, default=False
            If True, includes singleton components. Otherwise, excludes them.

        Yields
        ------
        Set[Any]
            Sets of hyperedge IDs representing each connected component.
        """
        return self.s_connected_components(s=s, hyperedges=True, return_singletons=return_singletons)

    def get_node_connected_subgraphs(
        self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None
    ) -> Iterator['Hypergraph']:
        """
        Yields subgraphs corresponding to each s-node-connected component.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.
        return_singletons : bool, optional, default=False
            If True, includes singleton components. Otherwise, excludes them.
        name : Optional[str], default=None
            Base name for the subgraphs. Each subgraph will have a unique name appended.

        Yields
        ------
        Hypergraph
            Subgraphs representing each connected component.
        """
        return self.s_component_subgraphs(
            s=s,
            hyperedges=False,
            return_singletons=return_singletons,
            name=name
        )

    def get_hyperedge_connected_subgraphs(
        self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None
    ) -> Iterator['Hypergraph']:
        """
        Yields subgraphs corresponding to each s-hyperedge-connected component.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.
        return_singletons : bool, optional, default=False
            If True, includes singleton components. Otherwise, excludes them.
        name : Optional[str], default=None
            Base name for the subgraphs. Each subgraph will have a unique name appended.

        Yields
        ------
        Hypergraph
            Subgraphs representing each connected component.
        """
        return self.s_component_subgraphs(
            s=s,
            hyperedges=True,
            return_singletons=return_singletons,
            name=name
        )

    def get_singleton_hyperedges(self) -> List[Any]:
        """
        Returns a list of singleton hyperedges.
        A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.

        Returns
        -------
        List[Any]
            A list of singleton hyperedge IDs.
        """
        singletons = []
        for he in self.hyperedges.values():
            if len(he.nodes) == 1:
                node = next(iter(he.nodes))
                node_degree = sum(1 for hyperedge in self.hyperedges.values() if node in hyperedge.nodes)
                if node_degree == 1:
                    singletons.append(he.id)
        return singletons

    def remove_singleton_hyperedges(self, name: Optional[str] = None) -> 'Hypergraph':
        """
        Constructs a clone of the hypergraph with singleton hyperedges removed.
        """
        singletons = self.get_singleton_hyperedges()
        if not singletons:
            return self.copy(name=name)

        new_hypergraph = self.remove_hyperedges(singletons, inplace=False)
        new_hypergraph.name = name if name else f"{self.name}_no_singleton_hyperedges"
        return new_hypergraph

    def s_connected_components(
        self, 
        s: int = 1, 
        hyperedges: bool = True, 
        return_singletons: bool = False
    ) -> Iterator[Set[Any]]:
        """
        Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.
        hyperedges : bool, optional, default=True
            If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components.
        return_singletons : bool, optional, default=False
            If True, includes singleton components. Otherwise, excludes them.

        Yields
        ------
        Set[Any]
            Sets of hyperedge IDs or node IDs representing each connected component.
        """
        if hyperedges:
            # s-hyperedge-connected: hyperedges are connected if they share at least s nodes
            hyperedge_graph = nx.Graph()
            hyperedge_ids = list(self.hyperedges.keys())
            hyperedge_graph.add_nodes_from(hyperedge_ids)

            for i, he1 in enumerate(hyperedge_ids):
                nodes1 = self.hyperedges[he1].nodes
                for he2 in hyperedge_ids[i + 1:]:
                    nodes2 = self.hyperedges[he2].nodes
                    shared_nodes = nodes1 & nodes2
                    if len(shared_nodes) >= s:
                        hyperedge_graph.add_edge(he1, he2)

            components = nx.connected_components(hyperedge_graph)
            for component in components:
                if not return_singletons and len(component) == 1:
                    continue
                yield component
        else:
            # s-node-connected: nodes are connected if they share at least s hyperedges
            node_graph = nx.Graph()
            node_ids = list(self.node_properties.keys())
            node_graph.add_nodes_from(node_ids)

            for i, node1 in enumerate(node_ids):
                hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}
                for node2 in node_ids[i + 1:]:
                    hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}
                    shared_hyperedges = hyperedges1 & hyperedges2
                    if len(shared_hyperedges) >= s:
                        node_graph.add_edge(node1, node2)

            components = nx.connected_components(node_graph)
            for component in components:
                if not return_singletons and len(component) == 1:
                    continue
                yield component
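
    # Example (sketch): list(hg.s_connected_components(s=2, hyperedges=True))
    # yields sets of hyperedge IDs chained together by pairs sharing >= 2 nodes;
    # singleton components are skipped unless return_singletons=True.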

    def s_component_subgraphs(
        self,
        s: int = 1,
        hyperedges: bool = True,
        return_singletons: bool = False,
        name: Optional[str] = None
    ) -> Iterator['Hypergraph']:
        """
        Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.

        Parameters
        ----------
        s : int, optional, default=1
            The connectivity level to check.
        hyperedges : bool, optional, default=True
            If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components.
        return_singletons : bool, optional, default=False
            If True, includes singleton components. Otherwise, excludes them.
        name : Optional[str], default=None
            Base name for the subgraphs. Each subgraph will have a unique name appended.

        Yields
        ------
        Hypergraph
            Subgraphs representing each connected component.
        """
        for idx, component in enumerate(
            self.s_connected_components(s=s, hyperedges=hyperedges, return_singletons=return_singletons)
        ):
            if hyperedges:
                yield self.restrict_to_specific_hyperedges(
                    hyperedges_to_retain=component, 
                    name=f"{name or self.name}_component_{idx}"
                )
            else:
                yield self.restrict_to_specific_nodes(
                    nodes_to_retain=component, 
                    name=f"{name or self.name}_component_{idx}"
                )

    def compute_node_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:
        """
        Returns the node diameters of the connected components in the hypergraph.

        Parameters
        ----------
        s : int, optional, default=1
            The number of shared hyperedges required for nodes to be considered adjacent.

        Returns
        -------
        Tuple[int, List[int], List[Set[Any]]]
            - Maximum diameter among all connected components.
            - List of diameters for each s-node connected component.
            - List of sets, each containing node IDs in an s-node connected component.

        Raises
        ------
        HypergraphError
            If the hypergraph is not s-connected or has no nodes.
        """
        A, node_id_map = self.adjacency_matrix(s=s, index=True)
        if A is None or A.shape[0] == 0:
            raise HypergraphError("The hypergraph has no nodes to compute diameters.")

        graph = nx.from_scipy_sparse_array(A)

        if not nx.is_connected(graph):
            raise HypergraphError(f"Hypergraph is not s-node-connected. s={s}")

        diams = []
        comps = []
        for component in nx.connected_components(graph):
            subgraph = graph.subgraph(component)
            if len(subgraph) == 1:
                diamc = 0  # Diameter of a single node is 0
            else:
                try:
                    diamc = nx.diameter(subgraph)
                except nx.NetworkXError:
                    diamc = float('inf')  # Infinite diameter if the subgraph is not connected
            diams.append(diamc)
            component_nodes = {node_id_map[node] for node in component}
            comps.append(component_nodes)

        if not diams:
            raise HypergraphError("No connected components found to compute diameters.")

        max_diam = max(diams)
        return max_diam, diams, comps

    def compute_hyperedge_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:
        """
        Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.

        Parameters
        ----------
        s : int, optional, default=1
            The number of shared nodes required for hyperedges to be considered adjacent.

        Returns
        -------
        Tuple[int, List[int], List[Set[Any]]]
            - Maximum diameter among all s-hyperedge-connected components.
            - List of diameters for each s-hyperedge connected component.
            - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.

        Raises
        ------
        HypergraphError
            If the hypergraph is not s-hyperedge-connected or has no hyperedges.
        """
        A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)
        if A is None or A.shape[0] == 0:
            raise HypergraphError("The hypergraph has no hyperedges to compute diameters.")

        graph = nx.from_scipy_sparse_array(A)

        if not nx.is_connected(graph):
            raise HypergraphError(f"Hypergraph is not s-hyperedge-connected. s={s}")

        diams = []
        comps = []
        for component in nx.connected_components(graph):
            subgraph = graph.subgraph(component)
            if len(subgraph) == 1:
                diamc = 0  # Diameter of a single hyperedge is 0
            else:
                try:
                    diamc = nx.diameter(subgraph)
                except nx.NetworkXError:
                    diamc = float('inf')  # Infinite diameter if the subgraph is not connected
            diams.append(diamc)
            component_hyperedges = {he_id_map[he] for he in component}
            comps.append(component_hyperedges)

        if not diams:
            raise HypergraphError("No connected components found to compute hyperedge diameters.")

        max_diam = max(diams)
        return max_diam, diams, comps

    def compute_node_diameter(self, s: int = 1) -> int:
        """
        Returns the diameter of the hypergraph based on s-node connectivity.

        Parameters
        ----------
        s : int, optional, default=1
            The number of shared hyperedges required for nodes to be considered adjacent.

        Returns
        -------
        int
            The diameter of the hypergraph.

        Raises
        ------
        HypergraphError
            If the hypergraph is not s-node-connected or has no nodes.
        """
        A, _ = self.adjacency_matrix(s=s, index=True)
        if A is None or A.shape[0] == 0:
            raise HypergraphError("The hypergraph has no nodes to compute diameter.")

        graph = nx.from_scipy_sparse_array(A)
        if not nx.is_connected(graph):
            raise HypergraphError(f"Hypergraph is not s-node-connected. s={s}")

        try:
            return nx.diameter(graph)
        except nx.NetworkXError as e:
            raise HypergraphError(f"Could not compute diameter: {e}")

    def compute_hyperedge_diameter(self, s: int = 1) -> int:
        """
        Returns the diameter of the hypergraph based on s-hyperedge connectivity.

        Parameters
        ----------
        s : int, optional, default=1
            The number of shared nodes required for hyperedges to be considered adjacent.

        Returns
        -------
        int
            The diameter of the hypergraph based on hyperedge connectivity.

        Raises
        ------
        HypergraphError
            If the hypergraph is not s-hyperedge-connected or has no hyperedges.
        """
        A, _ = self.hyperedge_adjacency_matrix(s=s, index=True)
        if A is None or A.shape[0] == 0:
            raise HypergraphError("The hypergraph has no hyperedges to compute diameter.")

        graph = nx.from_scipy_sparse_array(A)
        if not nx.is_connected(graph):
            raise HypergraphError(f"Hypergraph is not s-hyperedge-connected. s={s}")

        try:
            return nx.diameter(graph)
        except nx.NetworkXError as e:
            raise HypergraphError(f"Could not compute hyperedge diameter: {e}")

    def get_node_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:
        """
        Returns the shortest s-walk distance between two nodes in the hypergraph.

        Parameters
        ----------
        source : Any
            A node identifier in the hypergraph.
        target : Any
            A node identifier in the hypergraph.
        s : int, optional, default=1
            The number of shared hyperedges required for nodes to be considered adjacent.

        Returns
        -------
        Union[int, float]
            The shortest s-walk distance between the source and target nodes.
            Returns `float('inf')` if no path exists.

        Raises
        ------
        HypergraphError
            If either the source or target node does not exist in the hypergraph.
        """
        if source not in self.node_properties:
            raise HypergraphError(f"Source node '{source}' does not exist in the hypergraph.")
        if target not in self.node_properties:
            raise HypergraphError(f"Target node '{target}' does not exist in the hypergraph.")

        A, node_id_map = self.adjacency_matrix(s=s, index=True)
        if A is None:
            raise HypergraphError("Adjacency matrix could not be generated.")

        graph = nx.from_scipy_sparse_array(A)

        # nx.from_scipy_sparse_array labels graph nodes by matrix index, so map
        # node identifiers to indices first (node_id_map maps index -> node ID).
        index_of = {nid: idx for idx, nid in node_id_map.items()}

        try:
            distance = nx.shortest_path_length(graph, source=index_of[source], target=index_of[target])
            return distance
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            warnings.warn(f"No s-walk path between '{source}' and '{target}'. Returning infinity.")
            return float('inf')

    def get_hyperedge_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:
        """
        Returns the shortest s-walk distance between two hyperedges in the hypergraph.

        Parameters
        ----------
        source : Any
            A hyperedge identifier in the hypergraph.
        target : Any
            A hyperedge identifier in the hypergraph.
        s : int, optional, default=1
            The number of shared nodes required for hyperedges to be considered adjacent.

        Returns
        -------
        Union[int, float]
            The shortest s-walk distance between the source and target hyperedges.
            Returns `float('inf')` if no path exists.

        Raises
        ------
        HypergraphError
            If either the source or target hyperedge does not exist in the hypergraph.
        """
        if source not in self.hyperedges:
            raise HypergraphError(f"Source hyperedge '{source}' does not exist in the hypergraph.")
        if target not in self.hyperedges:
            raise HypergraphError(f"Target hyperedge '{target}' does not exist in the hypergraph.")

        A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)
        if A is None:
            raise HypergraphError("Hyperedge adjacency matrix could not be generated.")

        graph = nx.from_scipy_sparse_array(A)

        # Map hyperedge identifiers to matrix indices before querying the path
        # length (he_id_map maps index -> hyperedge ID).
        index_of = {hid: idx for idx, hid in he_id_map.items()}

        try:
            distance = nx.shortest_path_length(graph, source=index_of[source], target=index_of[target])
            return distance
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            warnings.warn(f"No s-walk path between hyperedges '{source}' and '{target}'. Returning infinity.")
            return float('inf')

    # Advanced Operations and Transformations

    def union(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':
        """
        Returns the union of the current hypergraph with another hypergraph.
        The union combines all nodes and hyperedges from both hypergraphs.

        Parameters
        ----------
        other : Hypergraph
            The hypergraph to union with.
        inplace : bool, optional, default=False
            If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance.
        name : Optional[str], default=None
            The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.

        Returns
        -------
        Hypergraph
            The resulting union hypergraph.

        Raises
        ------
        TypeError
            If `other` is not an instance of Hypergraph.
        """
        if not isinstance(other, Hypergraph):
            raise TypeError("The `other` parameter must be an instance of Hypergraph.")

        if inplace:
            # Add nodes from other
            for node_id, props in other.node_properties.items():
                if node_id not in self.node_properties:
                    self.add_node(node_id, properties=props, inplace=True)
                else:
                    # Optionally, merge properties
                    self.node_properties[node_id].update(props)
                    self.graph.nodes[node_id].update(props)

            # Add hyperedges from other
            for he_id, hyperedge in other.hyperedges.items():
                if he_id not in self.hyperedges:
                    self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)
                else:
                    # Optionally, merge properties and nodes
                    self.hyperedges[he_id].nodes.update(hyperedge.nodes)
                    self.hyperedge_properties[he_id].update(hyperedge.properties)
                    for node in hyperedge.nodes:
                        if node not in self.graph:
                            self.add_node(node)
                        self.graph.add_edge(he_id, node)
            if name:
                self.name = name
            return self
        else:
            # Create a new Hypergraph instance
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)
            new_name = name if name else f"Union_of_{self.name}_{other.name}"

            # Add nodes from other
            for node_id, props in other.node_properties.items():
                if node_id not in new_node_properties:
                    new_node_properties[node_id] = copy.deepcopy(props)
                    new_graph.add_node(node_id, bipartite='node', **props)

            # Add hyperedges from other
            for he_id, hyperedge in other.hyperedges.items():
                if he_id not in new_hyperedges:
                    new_hyperedges[he_id] = copy.deepcopy(hyperedge)
                    new_hyperedge_properties[he_id] = copy.deepcopy(other.hyperedge_properties[he_id])
                    new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])
                    new_bipartite_nodes.add(he_id)
                    for node in hyperedge.nodes:
                        new_graph.add_edge(he_id, node)
                else:
                    # Merge nodes and properties
                    new_hyperedges[he_id].nodes.update(hyperedge.nodes)
                    new_hyperedge_properties[he_id].update(other.hyperedge_properties[he_id])
                    for node in hyperedge.nodes:
                        new_graph.add_edge(he_id, node)

            # Construct the new Hypergraph
            return Hypergraph(
                hyperedges={
                    he_id: {
                        'nodes': list(he.nodes),
                        'properties': he.properties.copy()
                    } for he_id, he in new_hyperedges.items()
                },
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=new_name
            )

    def intersection(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':
        """
        Returns the intersection of the current hypergraph with another hypergraph.
        The intersection includes only nodes and hyperedges present in both hypergraphs.

        Parameters
        ----------
        other : Hypergraph
            The hypergraph to intersect with.
        inplace : bool, optional, default=False
            If True, modifies the current hypergraph to keep only the intersecting elements.
            Otherwise, returns a new Hypergraph instance.
        name : Optional[str], default=None
            The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.

        Returns
        -------
        Hypergraph
            The resulting intersection hypergraph.

        Raises
        ------
        TypeError
            If `other` is not an instance of Hypergraph.
        """
        if not isinstance(other, Hypergraph):
            raise TypeError("The `other` parameter must be an instance of Hypergraph.")

        intersect_nodes = set(self.node_properties.keys()) & set(other.node_properties.keys())
        intersect_hyperedges = set(self.hyperedges.keys()) & set(other.hyperedges.keys())

        if inplace:
            # Remove non-intersecting nodes and hyperedges
            nodes_to_remove = set(self.node_properties.keys()) - intersect_nodes
            hyperedges_to_remove = set(self.hyperedges.keys()) - intersect_hyperedges
            self.remove_nodes_from(nodes_to_remove, inplace=True)
            self.remove_hyperedges(hyperedges_to_remove, inplace=True)
            return self
        else:
            # Create a new Hypergraph instance
            new_hyperedges = {}
            new_node_properties = {node_id: copy.deepcopy(self.node_properties[node_id]) for node_id in intersect_nodes}
            new_hyperedge_properties = {}
            new_graph = nx.Graph()
            new_bipartite_nodes = set()

            for he_id in intersect_hyperedges:
                he_self = self.hyperedges[he_id]
                he_other = other.hyperedges[he_id]
                # Intersection hyperedges have the same nodes and merged properties
                new_nodes = set(he_self.nodes) & set(he_other.nodes)
                if not new_nodes:
                    continue  # Skip hyperedges with no common nodes
                new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties={})
                # Merge properties (could define specific rules)
                new_hyperedge_properties[he_id] = {**self.hyperedge_properties.get(he_id, {}), 
                                                   **other.hyperedge_properties.get(he_id, {})}
                new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])
                new_bipartite_nodes.add(he_id)
                for node in new_nodes:
                    new_graph.add_edge(he_id, node)

            return Hypergraph(
                hyperedges={
                    he_id: {
                        'nodes': list(he.nodes),
                        'properties': he.properties.copy()
                    } for he_id, he in new_hyperedges.items()
                },
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=name if name else f"Intersection_of_{self.name}_{other.name}"
            )

    def difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':
        """
        Returns the difference of the current hypergraph with another hypergraph.
        The difference includes nodes and hyperedges present in the current hypergraph but not in the other.

        Parameters
        ----------
        other : Hypergraph
            The hypergraph to subtract.
        inplace : bool, optional, default=False
            If True, modifies the current hypergraph by removing elements found in `other`.
            Otherwise, returns a new Hypergraph instance.
        name : Optional[str], default=None
            The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.

        Returns
        -------
        Hypergraph
            The resulting difference hypergraph.

        Raises
        ------
        TypeError
            If `other` is not an instance of Hypergraph.
        """
        if not isinstance(other, Hypergraph):
            raise TypeError("The `other` parameter must be an instance of Hypergraph.")

        if inplace:
            # Remove hyperedges present in other
            hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())
            self.remove_hyperedges(hyperedges_to_remove, inplace=True)
            # Remove nodes present in other
            nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())
            self.remove_nodes_from(nodes_to_remove, inplace=True)
            return self
        else:
            # Create a new Hypergraph instance
            new_hyperedges = {he_id: copy.deepcopy(he) for he_id, he in self.hyperedges.items() if he_id not in other.hyperedges}
            new_hyperedge_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items() if he_id not in other.hyperedges}
            new_node_properties = {node_id: copy.deepcopy(props) for node_id, props in self.node_properties.items() if node_id not in other.node_properties}

            # Reconstruct graph
            new_graph = nx.Graph()
            new_bipartite_nodes = set()
            for he_id, hyperedge in new_hyperedges.items():
                new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])
                new_bipartite_nodes.add(he_id)
                for node in hyperedge.nodes:
                    if node in new_node_properties:
                        new_graph.add_edge(he_id, node)

            return Hypergraph(
                hyperedges={
                    he_id: {
                        'nodes': list(he.nodes),
                        'properties': he.properties.copy()
                    } for he_id, he in new_hyperedges.items()
                },
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=name if name else f"Difference_of_{self.name}_{other.name}"
            )

    def symmetric_difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':
        """
        Returns the symmetric difference of the current hypergraph with another hypergraph.
        The symmetric difference includes elements present in either hypergraph but not in both.

        Parameters
        ----------
        other : Hypergraph
            The hypergraph to symmetric difference with.
        inplace : bool, optional, default=False
            If True, modifies the current hypergraph to keep only the symmetric difference elements.
            Otherwise, returns a new Hypergraph instance.
        name : Optional[str], default=None
            The name for the resulting hypergraph. If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.

        Returns
        -------
        Hypergraph
            The resulting symmetric difference hypergraph.

        Raises
        ------
        TypeError
            If `other` is not an instance of Hypergraph.
        """
        if not isinstance(other, Hypergraph):
            raise TypeError("The `other` parameter must be an instance of Hypergraph.")

        if inplace:
            # Hyperedges symmetric difference
            hyperedges_to_add = set(other.hyperedges.keys()) - set(self.hyperedges.keys())
            hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())
            self.remove_hyperedges(hyperedges_to_remove, inplace=True)
            for he_id in hyperedges_to_add:
                hyperedge = other.hyperedges[he_id]
                self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)

            # Nodes symmetric difference
            nodes_to_add = set(other.node_properties.keys()) - set(self.node_properties.keys())
            nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())
            self.remove_nodes_from(nodes_to_remove, inplace=True)
            for node_id in nodes_to_add:
                props = other.node_properties[node_id]
                self.add_node(node_id, properties=props, inplace=True)

            if name:
                self.name = name
            return self
        else:
            # Create a new Hypergraph instance
            union_hg = self.union(other)
            intersection_hg = self.intersection(other)
            return union_hg.difference(intersection_hg, name=name if name else f"SymmetricDifference_of_{self.name}_{other.name}")
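
    # Example (sketch): set algebra over hypergraphs
    #   hg_a.union(hg_b)                  # all nodes/hyperedges from both
    #   hg_a.intersection(hg_b)           # only elements present in both
    #   hg_a.difference(hg_b)             # elements of hg_a absent from hg_b
    #   hg_a.symmetric_difference(hg_b)   # union minus intersection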

    def transpose(self, name: Optional[str] = None) -> 'Hypergraph':
        """
        Transposes the hypergraph by swapping the roles of nodes and hyperedges.
        The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.

        Parameters
        ----------
        name : Optional[str], default=None
            The name assigned to the transposed hypergraph. If None, defaults to the original name suffixed with '_transposed'.

        Returns
        -------
        Hypergraph
            The transposed hypergraph.
        """
        transposed_hyperedges = {node_id: HyperEdge(id=node_id, nodes=set(), properties=copy.deepcopy(props))
                                 for node_id, props in self.node_properties.items()}
        transposed_node_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items()}

        for he_id, hyperedge in self.hyperedges.items():
            for node in hyperedge.nodes:
                if node in transposed_hyperedges:
                    transposed_hyperedges[node].nodes.add(he_id)

        # Construct the transposed hypergraph
        return Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in transposed_hyperedges.items()
            },
            node_properties=transposed_node_properties,
            hyperedge_properties={he_id: he.properties.copy() for he_id, he in transposed_hyperedges.items()},
            name=name if name else f"{self.name}_transposed"
        )

    def to_bipartite_graph(self, keep_data=False, directed=False) -> nx.Graph:
        """
        Creates a bipartite NetworkX graph from the hypergraph.
        The nodes and hyperedges of the hypergraph become nodes in the bipartite graph.
        For every hyperedge in the hypergraph and each node it connects to, there
        is an edge in the bipartite graph.

        Parameters
        ----------
        keep_data : bool, optional, default = False
            If True, includes the node and hyperedge properties in the NetworkX graph.
        directed : bool, optional, default = False
            If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.

        Returns
        -------
        networkx.Graph or networkx.DiGraph
            The bipartite graph representation of the hypergraph.
        """
        # Choose graph type based on directed flag
        B = nx.DiGraph() if directed else nx.Graph()

        if not keep_data:
            # Add nodes with bipartite attributes, where 0 indicates hyperedges and 1 indicates regular nodes
            B.add_nodes_from(self.hyperedges.keys(), bipartite=0)  # hyperedges
            B.add_nodes_from(self.node_properties.keys(), bipartite=1)  # nodes

            # Add edges between hyperedges and nodes based on hyperedges data
            for he_id, hyperedge in self.hyperedges.items():
                for node in hyperedge.nodes:
                    B.add_edge(he_id, node)
        else:
            # Add nodes with properties if keep_data is True
            for node_id, properties in self.node_properties.items():
                B.add_node(node_id, bipartite=1, **properties)

            for he_id, hyperedge in self.hyperedges.items():
                B.add_node(he_id, bipartite=0, **self.hyperedge_properties.get(he_id, {}))
                for node in hyperedge.nodes:
                    # Add edges with optional properties if keep_data is True
                    B.add_edge(he_id, node)

        return B
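
    # Example (sketch): round-trip through a bipartite graph. from_bipartite_graph
    # classifies nodes by their ID prefixes, so identifiers must carry them:
    #   B = hg.to_bipartite_graph()
    #   hg2 = Hypergraph.from_bipartite_graph(B, hyperedge_prefix="HE", node_prefix="N")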

    @classmethod
    def from_bipartite_graph(cls, bipartite_graph: nx.Graph, hyperedge_prefix: str = "HE", node_prefix: str = "N", name: Optional[str] = None) -> 'Hypergraph':
        """
        Constructs a Hypergraph instance from a bipartite graph.

        Parameters
        ----------
        bipartite_graph : nx.Graph
            A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes.
        hyperedge_prefix : str, optional, default="HE"
            The prefix to identify hyperedge nodes in the bipartite graph.
        node_prefix : str, optional, default="N"
            The prefix to identify regular nodes in the bipartite graph.
        name : Optional[str], default=None
            The name assigned to the new Hypergraph. If None, defaults to 'FromBipartiteGraph'.

        Returns
        -------
        Hypergraph
            The constructed Hypergraph instance.

        Raises
        ------
        ValueError
            If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.
        """
        hyperedges = {}
        node_properties = {}
        hyperedge_properties = {}
        name = name if name else "FromBipartiteGraph"

        for node in bipartite_graph.nodes(data=True):
            node_id, attrs = node
            if node_id.startswith(hyperedge_prefix):
                # It's a hyperedge
                hyperedges[node_id] = HyperEdge(id=node_id, nodes=set(), properties=attrs)
                hyperedge_properties[node_id] = copy.deepcopy(attrs)
            elif node_id.startswith(node_prefix):
                # It's a regular node
                node_properties[node_id] = copy.deepcopy(attrs)
            else:
                raise ValueError(f"Node '{node_id}' does not start with either hyperedge_prefix '{hyperedge_prefix}' or node_prefix '{node_prefix}'.")

        # Assign nodes to hyperedges based on edges in bipartite graph
        for he_id in hyperedges:
            connected_nodes = set(bipartite_graph.neighbors(he_id))
            hyperedges[he_id].nodes = connected_nodes

        # Construct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in hyperedges.items()
        }

        return cls(
            hyperedges=hyperedges_dict,
            node_properties=node_properties,
            hyperedge_properties=hyperedge_properties,
            name=name
        )
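
A minimal end-to-end sketch of the class above. The import path and the
dict-based constructor keywords are inferred from this source listing, so treat
it as illustrative rather than canonical:

from aeiva.hypergraph.hypergraph import Hypergraph

hg = Hypergraph(
    hyperedges={
        "HE1": {"nodes": ["N1", "N2"], "properties": {"weight": 1}},
        "HE2": {"nodes": ["N2", "N3"], "properties": {"weight": 2}},
    },
    node_properties={"N1": {}, "N2": {}, "N3": {}},
    name="Example",
)

hg.is_node_connected(s=1)          # True: N1-N2-N3 chain via shared hyperedges
hg.get_toplexes()                  # ['HE1', 'HE2']: neither contains the other
hg.get_node_distance("N1", "N3")   # 2: N1 -HE1- N2 -HE2- N3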
__contains__(item)

Checks if a node is in the hypergraph.

Parameters

item : Any
    The node identifier to check.

Returns

bool
    True if the node exists in the hypergraph, False otherwise.

Source code in src/aeiva/hypergraph/hypergraph.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def __contains__(self, item: Any) -> bool:
    """
    Checks if a node is in the hypergraph.

    Parameters
    ----------
    item : Any
        The node identifier to check.

    Returns
    -------
    bool
        True if the node exists in the hypergraph, False otherwise.
    """
    return item in self.node_properties
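
Membership is defined over nodes only; hyperedge IDs are not members. Continuing
the hg sketched above:

"N1" in hg    # True
"HE1" in hg   # False: hyperedge IDs do not count as nodes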
__eq__(other)

Checks if two hypergraphs are equal based on their hyperedges and nodes.

Parameters

other : Any
    The other object to compare.

Returns

bool
    True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.

Source code in src/aeiva/hypergraph/hypergraph.py
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def __eq__(self, other: Any) -> bool:
    """
    Checks if two hypergraphs are equal based on their hyperedges and nodes.

    Parameters
    ----------
    other : Any
        The other object to compare.

    Returns
    -------
    bool
        True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.
    """
    if not isinstance(other, Hypergraph):
        return False

    # Compare nodes and their properties
    if self.node_properties != other.node_properties:
        return False

    # Compare hyperedges and their properties
    if self.hyperedges.keys() != other.hyperedges.keys():
        return False

    for he_id in self.hyperedges:
        if self.hyperedges[he_id].nodes != other.hyperedges[he_id].nodes:
            return False
        if self.hyperedge_properties.get(he_id, {}) != other.hyperedge_properties.get(he_id, {}):
            return False

    return True
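
Equality is structural and ignores the name, so a copy compares equal (a sketch
assuming the copy helper used elsewhere in this class):

hg == hg.copy(name="Clone")   # True: same nodes, hyperedges, and properties
hg == "not a hypergraph"      # False: non-Hypergraph objects never compare equal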
__getitem__(node)

Retrieves the neighbors of a node in the hypergraph.

Neighbors are nodes that share at least one hyperedge with the given node.

Parameters

node : Any
    The node identifier.

Returns

Iterable[Any]
    An iterator over neighboring node identifiers.

Raises

HypergraphError
    If the node does not exist in the hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def __getitem__(self, node: Any) -> Iterable[Any]:
    """
    Retrieves the neighbors of a node in the hypergraph.

    Neighbors are nodes that share at least one hyperedge with the given node.

    Parameters
    ----------
    node : Any
        The node identifier.

    Returns
    -------
    Iterable[Any]
        An iterator over neighboring node identifiers.

    Raises
    ------
    HypergraphError
        If the node does not exist in the hypergraph.
    """
    if node not in self.node_properties:
        raise HypergraphError(f"Node '{node}' does not exist in the hypergraph.")

    # Get all hyperedges that include the node
    hyperedges = set(self.graph.neighbors(node))

    # Get all nodes connected by these hyperedges
    neighbors = set()
    for he_id in hyperedges:
        neighbors.update(self.hyperedges[he_id].nodes)

    neighbors.discard(node)  # Remove the node itself
    return neighbors
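
For instance, with the hg sketched above:

hg["N2"]   # {'N1', 'N3'}: nodes sharing at least one hyperedge with N2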
__iter__()

Allows iteration over the nodes of the hypergraph.

Yields

Any
    Node identifiers.

Source code in src/aeiva/hypergraph/hypergraph.py
188
189
190
191
192
193
194
195
196
197
def __iter__(self) -> Iterator[Any]:
    """
    Allows iteration over the nodes of the hypergraph.

    Yields
    ------
    Any
        Node identifiers.
    """
    return iter(self.node_properties)
__len__()

Returns the number of nodes in the hypergraph.

Returns

int
    Number of nodes.

Source code in src/aeiva/hypergraph/hypergraph.py
177
178
179
180
181
182
183
184
185
186
def __len__(self) -> int:
    """
    Returns the number of nodes in the hypergraph.

    Returns
    -------
    int
        Number of nodes.
    """
    return len(self.node_properties)
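
len() and iteration both range over nodes. Continuing the sketch:

len(hg)    # 3
list(hg)   # ['N1', 'N2', 'N3'] (order follows the underlying node dict)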
__repr__()

Official string representation of the hypergraph.

Returns

str
    A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.

Source code in src/aeiva/hypergraph/hypergraph.py
163
164
165
166
167
168
169
170
171
172
173
174
175
def __repr__(self) -> str:
    """
    Official string representation of the hypergraph.

    Returns
    -------
    str
        A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.
    """
    return (
        f"Hypergraph(name={self.name!r}, "
        f"nodes={len(self)}, hyperedges={len(self.hyperedges)})"
    )
__str__()

String representation of the hypergraph.

Returns

str
    A string describing the hypergraph with its name, number of nodes, and hyperedges.

Source code in src/aeiva/hypergraph/hypergraph.py
152
153
154
155
156
157
158
159
160
161
def __str__(self) -> str:
    """
    String representation of the hypergraph.

    Returns
    -------
    str
        A string describing the hypergraph with its name, number of nodes, and hyperedges.
    """
    return f"Hypergraph '{self.name}' with {len(self)} nodes and {len(self.hyperedges)} hyperedges."
add_hyperedge(he_id, nodes, properties=None)

Adds a hyperedge to the hypergraph.

Parameters

he_id : Any
    Unique identifier for the hyperedge.
nodes : Iterable[Any]
    Nodes connected by the hyperedge.
properties : Optional[Dict[str, Any]] = None
    Properties of the hyperedge.

Raises

HypergraphError
    If the hyperedge ID already exists.

Source code in src/aeiva/hypergraph/hypergraph.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
def add_hyperedge(
    self,
    he_id: Any,
    nodes: Iterable[Any],
    properties: Optional[Dict[str, Any]] = None
) -> None:
    """
    Adds a hyperedge to the hypergraph.

    Parameters
    ----------
    he_id : Any
        Unique identifier for the hyperedge.
    nodes : Iterable[Any]
        Nodes connected by the hyperedge.
    properties : Optional[Dict[str, Any]] = None
        Properties of the hyperedge.

    Raises
    ------
    HypergraphError
        If the hyperedge ID already exists.
    """
    if he_id in self.hyperedges:
        raise HypergraphError(f"Hyperedge '{he_id}' already exists.")

    hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)
    self.hyperedges[he_id] = hyperedge
    self.hyperedge_properties[he_id] = copy.deepcopy(properties) if properties else {}

    # Add hyperedge to bipartite graph
    self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties[he_id])
    self.bipartite_nodes.add(he_id)

    # Add edges between hyperedge and nodes
    for node in hyperedge.nodes:
        if node not in self.graph:
            self.graph.add_node(node, bipartite='node', **self.node_properties.get(node, {}))
        self.graph.add_edge(he_id, node)
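
A minimal usage sketch, continuing the hg example:

hg.add_hyperedge("HE3", nodes=["N1", "N3"], properties={"weight": 5})
# Nodes not yet present are added to the underlying bipartite graph on the fly.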
add_hyperedges_from(hyperedges, inplace=True)

Adds multiple hyperedges with attributes to the hypergraph.

Parameters

hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]
    An iterable of hyperedge identifiers or tuples of (he_id, attributes).
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If any hyperedge ID already exists.
ValueError
    If any tuple does not contain exactly two elements or if attributes are not dictionaries.

Source code in src/aeiva/hypergraph/hypergraph.py
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
def add_hyperedges_from(
    self,
    hyperedges: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],
    inplace: bool = True
) -> 'Hypergraph':
    """
    Adds multiple hyperedges with attributes to the hypergraph.

    Parameters
    ----------
    hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]
        An iterable of hyperedge identifiers or tuples of (he_id, attributes).
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If any hyperedge ID already exists.
    ValueError
        If any tuple does not contain exactly two elements or if attributes are not dictionaries.
    """
    new_hyperedges = []
    for item in hyperedges:
        if isinstance(item, tuple):
            if len(item) != 2 or not isinstance(item[1], dict):
                raise ValueError(f"Each tuple must be of the form (he_id, attributes). Invalid tuple: {item}")
            he_id, attrs = item
        else:
            he_id, attrs = item, {}

        if he_id in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' already exists.")

        hyperedge = HyperEdge(id=he_id, nodes=[], properties=attrs.copy())
        new_hyperedges.append(hyperedge)

    if inplace:
        for hyperedge in new_hyperedges:
            self.hyperedges[hyperedge.id] = hyperedge
            self.hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)
            self.graph.add_node(hyperedge.id, bipartite='hyperedge', **self.hyperedge_properties[hyperedge.id])
            self.bipartite_nodes.add(hyperedge.id)
        return self
    else:
        # Create a new Hypergraph instance with added hyperedges
        new_hyperedges_dict = copy.deepcopy(self.hyperedges)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        for hyperedge in new_hyperedges:
            new_hyperedges_dict[hyperedge.id] = hyperedge
            new_hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)
            new_graph.add_node(hyperedge.id, bipartite='hyperedge', **new_hyperedge_properties[hyperedge.id])
            new_bipartite_nodes.add(hyperedge.id)

        # Reconstruct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges_dict.items()
        }

        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
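
A usage sketch: hyperedges created this way start with empty node sets, so
incidences are attached afterwards (e.g., via add_incidence):

hg.add_hyperedges_from([
    "HE4",                    # bare identifier, no attributes
    ("HE5", {"weight": 2}),   # (he_id, attributes) tuple
])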
add_incidence(he_id, node_id, attributes=None, inplace=True)

Adds a single incidence with attributes to the hypergraph.

Parameters

he_id : Any
    Identifier of the hyperedge.
node_id : Any
    Identifier of the node.
attributes : Optional[Dict[str, Any]] = None
    Properties to add to the incidence as key-value pairs.
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If the hyperedge or node does not exist, or if the incidence already exists.

Source code in src/aeiva/hypergraph/hypergraph.py
def add_incidence(
    self,
    he_id: Any,
    node_id: Any,
    attributes: Optional[Dict[str, Any]] = None,
    inplace: bool = True
) -> 'Hypergraph':
    """
    Adds a single incidence with attributes to the hypergraph.

    Parameters
    ----------
    he_id : Any
        Identifier of the hyperedge.
    node_id : Any
        Identifier of the node.
    attributes : Optional[Dict[str, Any]] = None
        Properties to add to the incidence as key-value pairs.
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If the hyperedge or node does not exist, or if the incidence already exists.
    """
    if he_id not in self.hyperedges:
        raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
    if node_id not in self.node_properties:
        raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
    if node_id in self.hyperedges[he_id].nodes:
        raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.")

    if inplace:
        # Add node to HyperEdge's nodes
        self.hyperedges[he_id].add_node(node_id)
        # Update hyperedge_properties if attributes provided
        if attributes:
            self.hyperedge_properties[he_id].update(attributes)
        # Add edge in graph with attributes
        self.graph.add_edge(he_id, node_id, **(attributes if attributes else {}))
        return self
    else:
        # Create a new Hypergraph instance with the incidence added
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        # Add node to HyperEdge's nodes
        new_hyperedges[he_id].add_node(node_id)
        # Update hyperedge_properties if attributes provided
        if attributes:
            new_hyperedge_properties[he_id].update(attributes)
        # Add edge in graph with attributes
        new_graph.add_edge(he_id, node_id, **(attributes if attributes else {}))

        # Reconstruct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        }

        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
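
Continuing the illustrative hg from the sketch above (both endpoints must already exist):

# Link node 'N1' to hyperedge 'HE1'; adding the same pair twice raises HypergraphError.
hg.add_incidence("HE1", "N1", attributes={"role": "member"})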
add_incidences_from(incidences, inplace=True)

Adds a collection of incidences to the hypergraph.

Parameters

incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]]
    Incidence tuples as:
        - (he_id, node_id)
        - (he_id, node_id, attributes)
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidences.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If any hyperedge or node does not exist, or if any incidence already exists.
ValueError
    If the structure of any incidence tuple is invalid.

Source code in src/aeiva/hypergraph/hypergraph.py
def add_incidences_from(
    self,
    incidences: Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]],
    inplace: bool = True
) -> 'Hypergraph':
    """
    Adds a collection of incidences to the hypergraph.

    Parameters
    ----------
    incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]]
        Incidence tuples as:
            - (he_id, node_id)
            - (he_id, node_id, attributes)

    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidences.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If any hyperedge or node does not exist, or if any incidence already exists.
    ValueError
        If the structure of any incidence tuple is invalid.
    """
    new_incidences = []
    for pr in incidences:
        if not isinstance(pr, tuple):
            raise ValueError(f"Each incidence must be a tuple, got {type(pr)}")
        if len(pr) == 2:
            he_id, node_id = pr
            attrs = {}
        elif len(pr) == 3:
            he_id, node_id, attrs = pr
            if not isinstance(attrs, dict):
                raise ValueError(f"Attributes must be a dictionary, got {type(attrs)}")
        else:
            raise ValueError(f"Incidence tuples must be of length 2 or 3, got {len(pr)}")

        if he_id not in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
        if node_id in self.hyperedges[he_id].nodes:
            raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.")

        new_incidences.append((he_id, node_id, attrs.copy()))

    if inplace:
        for he_id, node_id, attrs in new_incidences:
            # Add node to HyperEdge's nodes
            self.hyperedges[he_id].add_node(node_id)
            # Update hyperedge_properties if attributes provided
            if attrs:
                self.hyperedge_properties[he_id].update(attrs)
            # Add edge in graph with attributes
            self.graph.add_edge(he_id, node_id, **(attrs if attrs else {}))
        return self
    else:
        # Create a new Hypergraph instance with the incidences added
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        for he_id, node_id, attrs in new_incidences:
            # Add node to HyperEdge's nodes
            new_hyperedges[he_id].add_node(node_id)
            # Update hyperedge_properties if attributes provided
            if attrs:
                new_hyperedge_properties[he_id].update(attrs)
            # Add edge in graph with attributes
            new_graph.add_edge(he_id, node_id, **(attrs if attrs else {}))

        # Reconstruct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        }

        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
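
A batch-insertion sketch under the same running assumptions (identifiers are hypothetical):

# 2-tuples add a bare incidence; 3-tuples also attach attributes.
hg.add_incidences_from([
    ("HE1", "N2"),
    ("HE2", "N1", {"weight": 0.5}),
])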
add_node(node_id, properties=None, inplace=True)

Adds a node to the hypergraph.

Parameters

node_id : Any
    Identifier for the node.
properties : Optional[Dict[str, Any]] = None
    Properties of the node.
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If the node ID already exists.

Source code in src/aeiva/hypergraph/hypergraph.py
def add_node(
    self,
    node_id: Any,
    properties: Optional[Dict[str, Any]] = None,
    inplace: bool = True
) -> 'Hypergraph':
    """
    Adds a node to the hypergraph.

    Parameters
    ----------
    node_id : Any
        Identifier for the node.
    properties : Optional[Dict[str, Any]] = None
        Properties of the node.
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If the node ID already exists.
    """
    if node_id in self.node_properties:
        raise HypergraphError(f"Node '{node_id}' already exists in the hypergraph.")

    if inplace:
        self.node_properties[node_id] = copy.deepcopy(properties) if properties else {}
        self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])
        return self
    else:
        # Create a new Hypergraph instance with the added node
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        new_node_properties[node_id] = copy.deepcopy(properties) if properties else {}
        new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])

        return Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            },
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
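
Illustrative usage, continuing the running hg sketch:

hg.add_node("N3", properties={"label": "sensor"})  # mutates hg in place
hg4 = hg.add_node("N4", inplace=False)             # hg unchanged; new instance returned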
add_nodes_from(nodes, inplace=True)

Adds multiple nodes with attributes to the hypergraph.

Parameters

nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]
    An iterable of node identifiers or tuples of (node_id, attributes).
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If any node ID already exists.
ValueError
    If any tuple does not contain exactly two elements or if attributes are not dictionaries.

Source code in src/aeiva/hypergraph/hypergraph.py
def add_nodes_from(
    self,
    nodes: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],
    inplace: bool = True
) -> 'Hypergraph':
    """
    Adds multiple nodes with attributes to the hypergraph.

    Parameters
    ----------
    nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]
        An iterable of node identifiers or tuples of (node_id, attributes).
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If any node ID already exists.
    ValueError
        If any tuple does not contain exactly two elements or if attributes are not dictionaries.
    """
    new_nodes = {}
    for item in nodes:
        if isinstance(item, tuple):
            if len(item) != 2 or not isinstance(item[1], dict):
                raise ValueError(f"Each tuple must be of the form (node_id, attributes). Invalid tuple: {item}")
            node_id, attrs = item
        else:
            node_id, attrs = item, {}

        if node_id in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' already exists in the hypergraph.")

        new_nodes[node_id] = copy.deepcopy(attrs)

    if inplace:
        for node_id, attrs in new_nodes.items():
            self.node_properties[node_id] = attrs
            self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])
        return self
    else:
        # Create a new Hypergraph instance with the added nodes
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        for node_id, attrs in new_nodes.items():
            new_node_properties[node_id] = attrs
            new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])

        return Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            },
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
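
A short sketch mixing bare identifiers and (node_id, attributes) tuples (hypothetical values):

hg.add_nodes_from(["N5", ("N6", {"color": "red"})])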
adjacency_matrix(s=1, index=False)

Generates the adjacency matrix for nodes based on s-node connectivity.

Source code in src/aeiva/hypergraph/hypergraph.py
def adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[int, Any]]:
    """
    Generates the adjacency matrix for nodes based on s-node connectivity.
    """
    from scipy.sparse import lil_matrix

    node_ids = list(self.node_properties.keys())
    node_index = {node_id: idx for idx, node_id in enumerate(node_ids)}
    size = len(node_ids)
    if size == 0:
        return None, {}

    A = lil_matrix((size, size), dtype=int)
    for he in self.hyperedges.values():
        nodes = list(he.nodes)
        for i in range(len(nodes)):
            for j in range(i + 1, len(nodes)):
                A[node_index[nodes[i]], node_index[nodes[j]]] += 1

    # Apply the threshold s and convert to binary
    A = (A >= s).astype(int)
    A = A.tocsr()

    if index:
        return A, node_index
    return A, {}
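
An illustrative call, continuing the running hg sketch (scipy is required):

# Two nodes become adjacent when they co-occur in at least s hyperedges;
# counts are thresholded to 0/1 and returned as a CSR matrix.
A, node_index = hg.adjacency_matrix(s=1, index=True)
if A is not None:
    print(A.toarray())   # dense view, sensible only for small hypergraphs
    print(node_index)    # maps each node ID to its row/column index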
collapse_duplicate_hyperedges(name=None, use_uids=None, use_counts=False, return_counts=True, return_equivalence_classes=False, aggregate_properties_by=None)

Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.

Parameters

name : Optional[str], default=None
    The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.
use_uids : Optional[List[Any]] = None
    Specifies the hyperedge identifiers to use as representatives for each equivalence class. If two identifiers occur in the same equivalence class, the first one found in use_uids is used. If None, the first encountered hyperedge in each class is used as the representative.
use_counts : bool, optional, default=False
    If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').
return_counts : bool, optional, default=True
    If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.
return_equivalence_classes : bool, optional, default=False
    If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.
aggregate_properties_by : Optional[Dict[str, str]] = None
    A dictionary specifying aggregation methods for hyperedge properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}). Properties not specified will use the 'first' aggregation.

Returns

Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]
    - If return_equivalence_classes=False, returns the new collapsed hypergraph.
    - If return_equivalence_classes=True, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.

Raises

HypergraphError
    If the hypergraph is empty or improperly structured.

Source code in src/aeiva/hypergraph/hypergraph.py
def collapse_duplicate_hyperedges(
    self,
    name: Optional[str] = None,
    use_uids: Optional[List[Any]] = None,
    use_counts: bool = False,
    return_counts: bool = True,
    return_equivalence_classes: bool = False,
    aggregate_properties_by: Optional[Dict[str, str]] = None,
) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:
    """
    Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.

    Parameters
    ----------
    name : Optional[str], default=None
        The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.

    use_uids : Optional[List[Any]] = None
        Specifies the hyperedge identifiers to use as representatives for each equivalence class.
        If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.
        If None, the first encountered hyperedge in each class is used as the representative.

    use_counts : bool, optional, default=False
        If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').

    return_counts : bool, optional, default=True
        If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.

    return_equivalence_classes : bool, optional, default=False
        If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.

    aggregate_properties_by : Optional[Dict[str, str]] = None
        A dictionary specifying aggregation methods for hyperedge properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).
        Properties not specified will use the 'first' aggregation.

    Returns
    -------
    Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]
        - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.
        - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.

    Raises
    ------
    HypergraphError
        If the hypergraph is empty or improperly structured.
    """
    if not self.hyperedges:
        raise HypergraphError("Cannot collapse hyperedges in an empty hypergraph.")

    # Identify equivalence classes based on identical node memberships
    membership_to_hyperedges: Dict[frozenset, Set[Any]] = {}
    for he_id, hyperedge in self.hyperedges.items():
        key = frozenset(hyperedge.nodes)
        membership_to_hyperedges.setdefault(key, set()).add(he_id)

    # Filter out classes with only one hyperedge (no duplicates)
    equivalence_classes = [hes for hes in membership_to_hyperedges.values() if len(hes) > 1]
    if not equivalence_classes:
        # No duplicates to collapse; return the original hypergraph
        return self if not return_equivalence_classes else (self, {})

    # Prepare aggregation methods
    aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {"weight": "sum"}

    # Initialize mapping from old hyperedges to new hyperedges
    hyperedge_mapping: Dict[Any, Any] = {}
    equivalence_class_dict: Dict[Any, Set[Any]] = {}

    for eq_class in equivalence_classes:
        # Determine representative
        if use_uids:
            # Select the first UID from use_uids that is in the equivalence class
            representative = next((uid for uid in use_uids if uid in eq_class), None)
            if not representative:
                # Fallback to the first hyperedge in the equivalence class
                representative = next(iter(eq_class))
        else:
            # Use the first hyperedge in the equivalence class as representative
            representative = next(iter(eq_class))

        # Optionally rename with counts
        if use_counts:
            new_representative = f"{representative}:{len(eq_class)}"
        else:
            new_representative = representative

        # Map all hyperedges in the class to the representative
        for he in eq_class:
            hyperedge_mapping[he] = new_representative

        # Store the equivalence class
        equivalence_class_dict[new_representative] = eq_class

    # Replace hyperedge IDs in incidences based on mapping
    new_hyperedges = {}
    for he_id, hyperedge in self.hyperedges.items():
        new_he_id = hyperedge_mapping.get(he_id, he_id)
        if new_he_id not in new_hyperedges:
            new_hyperedges[new_he_id] = HyperEdge(id=new_he_id, nodes=hyperedge.nodes.copy(), properties=copy.deepcopy(hyperedge.properties))
        else:
            new_hyperedges[new_he_id].nodes.update(hyperedge.nodes)

    # Aggregate hyperedge properties
    for he_id, hyperedge in new_hyperedges.items():
        if he_id in equivalence_class_dict:
            aggregated_props = {}
            for prop, agg_func in aggregate_properties_by.items():
                values = [self.hyperedge_properties[old_he].get(prop, 0) for old_he in equivalence_class_dict[he_id]]
                if agg_func == 'sum':
                    aggregated_props[prop] = sum(values)
                elif agg_func == 'mean':
                    aggregated_props[prop] = sum(values) / len(values) if values else 0
                elif agg_func == 'max':
                    aggregated_props[prop] = max(values) if values else None
                elif agg_func == 'min':
                    aggregated_props[prop] = min(values) if values else None
                else:
                    aggregated_props[prop] = values[0] if values else None  # Default to first
            new_hyperedges[he_id].properties.update(aggregated_props)

    # Handle equivalence class size
    if use_counts:
        for he_id in equivalence_class_dict:
            new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])
    elif return_counts:
        for he_id in new_hyperedges:
            if he_id in equivalence_class_dict:
                new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])
            else:
                new_hyperedges[he_id].properties['equivalence_class_size'] = 1

    # Initialize the collapsed hypergraph
    collapsed_hypergraph = Hypergraph(
        hyperedges={
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        },
        node_properties=copy.deepcopy(self.node_properties),
        hyperedge_properties={
            he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()
        },
        name=name if name else f"{self.name}_collapsed_hyperedges"
    )

    if return_equivalence_classes:
        return collapsed_hypergraph, equivalence_class_dict
    else:
        return collapsed_hypergraph
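
A hedged sketch, assuming two hypothetical hyperedges with identical node sets:

# Hyperedges spanning exactly the same nodes collapse into one representative.
collapsed, classes = hg.collapse_duplicate_hyperedges(
    use_counts=True,                            # representative renamed, e.g. 'HE1:2'
    aggregate_properties_by={"weight": "sum"},  # sum 'weight' across each class
    return_equivalence_classes=True,
)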
collapse_duplicate_nodes(name=None, use_uids=None, use_counts=False, return_counts=True, return_equivalence_classes=False, aggregate_properties_by=None)

Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.

Parameters

name : Optional[str], default=None
    The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.
use_uids : Optional[List[Any]] = None
    Specifies the node identifiers to use as representatives for each equivalence class. If two identifiers occur in the same equivalence class, the first one found in use_uids is used. If None, the first encountered node in each class is used as the representative.
use_counts : bool, optional, default=False
    If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').
return_counts : bool, optional, default=True
    If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.
return_equivalence_classes : bool, optional, default=False
    If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.
aggregate_properties_by : Optional[Dict[str, str]] = None
    A dictionary specifying aggregation methods for node properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}). Properties not specified will use the 'first' aggregation.

Returns

Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]
    - If return_equivalence_classes=False, returns the new collapsed hypergraph.
    - If return_equivalence_classes=True, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.

Raises

HypergraphError
    If the hypergraph is empty or improperly structured.

Source code in src/aeiva/hypergraph/hypergraph.py
def collapse_duplicate_nodes(
    self,
    name: Optional[str] = None,
    use_uids: Optional[List[Any]] = None,
    use_counts: bool = False,
    return_counts: bool = True,
    return_equivalence_classes: bool = False,
    aggregate_properties_by: Optional[Dict[str, str]] = None,
) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:
    """
    Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.

    Parameters
    ----------
    name : Optional[str], default=None
        The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.

    use_uids : Optional[List[Any]] = None
        Specifies the node identifiers to use as representatives for each equivalence class.
        If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.
        If None, the first encountered node in each class is used as the representative.

    use_counts : bool, optional, default=False
        If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').

    return_counts : bool, optional, default=True
        If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.

    return_equivalence_classes : bool, optional, default=False
        If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.

    aggregate_properties_by : Optional[Dict[str, str]] = None
        A dictionary specifying aggregation methods for node properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).
        Properties not specified will use the 'first' aggregation.

    Returns
    -------
    Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]
        - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.
        - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.

    Raises
    ------
    HypergraphError
        If the hypergraph is empty or improperly structured.
    """
    if not self.node_properties:
        raise HypergraphError("Cannot collapse nodes in an empty hypergraph.")

    # Identify equivalence classes based on identical hyperedge memberships
    membership_to_nodes: Dict[frozenset, Set[Any]] = {}
    for node_id, node_props in self.node_properties.items():
        key = frozenset(self.get_hyperedges_of_node(node_id))
        membership_to_nodes.setdefault(key, set()).add(node_id)

    # Filter out classes with only one node (no duplicates)
    equivalence_classes = [nodes for nodes in membership_to_nodes.values() if len(nodes) > 1]
    if not equivalence_classes:
        # No duplicates to collapse; return the original hypergraph
        return self if not return_equivalence_classes else (self, {})

    # Prepare aggregation methods
    aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {"weight": "sum"}

    # Initialize mapping from old nodes to new nodes
    node_mapping: Dict[Any, Any] = {}
    equivalence_class_dict: Dict[Any, Set[Any]] = {}

    for eq_class in equivalence_classes:
        # Determine representative
        if use_uids:
            # Select the first UID from use_uids that is in the equivalence class
            representative = next((uid for uid in use_uids if uid in eq_class), None)
            if not representative:
                # Fallback to the first node in the equivalence class
                representative = next(iter(eq_class))
        else:
            # Use the first node in the equivalence class as representative
            representative = next(iter(eq_class))

        # Optionally rename with counts
        if use_counts:
            new_representative = f"{representative}:{len(eq_class)}"
        else:
            new_representative = representative

        # Map all nodes in the class to the representative
        for node in eq_class:
            node_mapping[node] = new_representative

        # Store the equivalence class
        equivalence_class_dict[new_representative] = eq_class

    # Replace node IDs in hyperedges based on mapping
    new_hyperedges = {}
    for he_id, hyperedge in self.hyperedges.items():
        new_nodes = set()
        for node_id in hyperedge.nodes:
            new_node_id = node_mapping.get(node_id, node_id)
            new_nodes.add(new_node_id)
        new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties=copy.deepcopy(hyperedge.properties))

    # Aggregate node properties
    new_node_properties = {}
    for node_id, node_props in self.node_properties.items():
        new_node_id = node_mapping.get(node_id, node_id)
        if new_node_id not in new_node_properties:
            new_node_properties[new_node_id] = copy.deepcopy(node_props)
        else:
            for prop, agg_func in aggregate_properties_by.items():
                if prop in node_props:
                    if agg_func == 'sum':
                        new_node_properties[new_node_id][prop] = new_node_properties[new_node_id].get(prop, 0) + node_props[prop]
                    elif agg_func == 'mean':
                        # To calculate mean, store sum and count
                        if 'sum_' + prop not in new_node_properties[new_node_id]:
                            new_node_properties[new_node_id]['sum_' + prop] = node_props[prop]
                            new_node_properties[new_node_id]['count_' + prop] = 1
                        else:
                            new_node_properties[new_node_id]['sum_' + prop] += node_props[prop]
                            new_node_properties[new_node_id]['count_' + prop] += 1
                        # Calculate mean at the end
                    elif agg_func == 'max':
                        current_max = new_node_properties[new_node_id].get(prop, float('-inf'))
                        new_node_properties[new_node_id][prop] = max(current_max, node_props[prop])
                    elif agg_func == 'min':
                        current_min = new_node_properties[new_node_id].get(prop, float('inf'))
                        new_node_properties[new_node_id][prop] = min(current_min, node_props[prop])
                    else:
                        new_node_properties[new_node_id][prop] = node_props[prop]  # Default to last
    # Finalize mean calculations
    for node_id, props in new_node_properties.items():
        for prop in list(props.keys()):
            if prop.startswith('sum_'):
                base_prop = prop[4:]
                sum_val = props[prop]
                count_val = props.get('count_' + base_prop, 1)
                new_node_properties[node_id][base_prop] = sum_val / count_val if count_val > 0 else 0
                del new_node_properties[node_id][prop]
                del new_node_properties[node_id]['count_' + base_prop]

    # Handle equivalence class size
    if use_counts:
        for node_id in equivalence_class_dict:
            new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])
    elif return_counts:
        for node_id in new_node_properties:
            if node_id in equivalence_class_dict:
                new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])
            else:
                new_node_properties[node_id]['equivalence_class_size'] = 1

    # Initialize the collapsed hypergraph
    collapsed_hypergraph = Hypergraph(
        hyperedges={
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        },
        node_properties=new_node_properties,
        hyperedge_properties={
            he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()
        },
        name=name if name else f"{self.name}_collapsed_nodes"
    )

    if return_equivalence_classes:
        return collapsed_hypergraph, equivalence_class_dict
    else:
        return collapsed_hypergraph
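
The node-side counterpart, sketched under the same assumptions:

# Nodes that belong to exactly the same hyperedges merge into one representative.
collapsed = hg.collapse_duplicate_nodes(aggregate_properties_by={"weight": "mean"})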
compute_hyperedge_diameter(s=1)

Returns the diameter of the hypergraph based on s-hyperedge connectivity.

Parameters

s : int, optional, default=1
    The number of shared nodes required for hyperedges to be considered adjacent.

Returns

int
    The diameter of the hypergraph based on hyperedge connectivity.

Raises

HypergraphError
    If the hypergraph is not s-hyperedge-connected or has no hyperedges.

Source code in src/aeiva/hypergraph/hypergraph.py
def compute_hyperedge_diameter(self, s: int = 1) -> int:
    """
    Returns the diameter of the hypergraph based on s-hyperedge connectivity.

    Parameters
    ----------
    s : int, optional, default=1
        The number of shared nodes required for hyperedges to be considered adjacent.

    Returns
    -------
    int
        The diameter of the hypergraph based on hyperedge connectivity.

    Raises
    ------
    HypergraphError
        If the hypergraph is not s-hyperedge-connected or has no hyperedges.
    """
    A, _ = self.hyperedge_adjacency_matrix(s=s, index=True)
    if A is None or A.shape[0] == 0:
        raise HypergraphError("The hypergraph has no hyperedges to compute diameter.")

    graph = nx.from_scipy_sparse_array(A)
    if not nx.is_connected(graph):
        raise HypergraphError(f"Hypergraph is not s-hyperedge-connected. s={s}")

    try:
        return nx.diameter(graph)
    except nx.NetworkXError as e:
        raise HypergraphError(f"Could not compute hyperedge diameter: {e}")
compute_hyperedge_diameters(s=1)

Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.

Parameters

s : int, optional, default=1
    The number of shared nodes required for hyperedges to be considered adjacent.

Returns

Tuple[int, List[int], List[Set[Any]]]
    - Maximum diameter among all s-hyperedge-connected components.
    - List of diameters for each s-hyperedge connected component.
    - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.

Raises

HypergraphError
    If the hypergraph is not s-hyperedge-connected or has no hyperedges.

Source code in src/aeiva/hypergraph/hypergraph.py
def compute_hyperedge_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:
    """
    Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.

    Parameters
    ----------
    s : int, optional, default=1
        The number of shared nodes required for hyperedges to be considered adjacent.

    Returns
    -------
    Tuple[int, List[int], List[Set[Any]]]
        - Maximum diameter among all s-hyperedge-connected components.
        - List of diameters for each s-hyperedge connected component.
        - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.

    Raises
    ------
    HypergraphError
        If the hypergraph is not s-hyperedge-connected or has no hyperedges.
    """
    A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)
    if A is None or A.shape[0] == 0:
        raise HypergraphError("The hypergraph has no hyperedges to compute diameters.")

    graph = nx.from_scipy_sparse_array(A)

    if not nx.is_connected(graph):
        raise HypergraphError(f"Hypergraph is not s-hyperedge-connected. s={s}")

    diams = []
    comps = []
    for component in nx.connected_components(graph):
        subgraph = graph.subgraph(component)
        if len(subgraph) == 1:
            diamc = 0  # Diameter of a single hyperedge is 0
        else:
            try:
                diamc = nx.diameter(subgraph)
            except nx.NetworkXError:
                diamc = float('inf')  # Infinite diameter if the subgraph is not connected
        diams.append(diamc)
        component_hyperedges = {he_id_map[he] for he in component}
        comps.append(component_hyperedges)

    if not diams:
        raise HypergraphError("No connected components found to compute hyperedge diameters.")

    max_diam = max(diams)
    return max_diam, diams, comps
compute_node_diameter(s=1)

Returns the diameter of the hypergraph based on s-node connectivity.

Parameters

s : int, optional, default=1
    The number of shared hyperedges required for nodes to be considered adjacent.

Returns

int
    The diameter of the hypergraph.

Raises

HypergraphError
    If the hypergraph is not s-node-connected or has no nodes.

Source code in src/aeiva/hypergraph/hypergraph.py
def compute_node_diameter(self, s: int = 1) -> int:
    """
    Returns the diameter of the hypergraph based on s-node connectivity.

    Parameters
    ----------
    s : int, optional, default=1
        The number of shared hyperedges required for nodes to be considered adjacent.

    Returns
    -------
    int
        The diameter of the hypergraph.

    Raises
    ------
    HypergraphError
        If the hypergraph is not s-node-connected or has no nodes.
    """
    A, _ = self.adjacency_matrix(s=s, index=True)
    if A is None or A.shape[0] == 0:
        raise HypergraphError("The hypergraph has no nodes to compute diameter.")

    graph = nx.from_scipy_sparse_array(A)
    if not nx.is_connected(graph):
        raise HypergraphError(f"Hypergraph is not s-node-connected. s={s}")

    try:
        return nx.diameter(graph)
    except nx.NetworkXError as e:
        raise HypergraphError(f"Could not compute diameter: {e}")
compute_node_diameters(s=1)

Returns the node diameters of the connected components in the hypergraph.

Parameters

s : int, optional, default=1
    The number of shared hyperedges required for nodes to be considered adjacent.

Returns

Tuple[int, List[int], List[Set[Any]]]
    - Maximum diameter among all connected components.
    - List of diameters for each s-node connected component.
    - List of sets, each containing node IDs in an s-node connected component.

Raises

HypergraphError
    If the hypergraph is not s-connected or has no nodes.

Source code in src/aeiva/hypergraph/hypergraph.py
def compute_node_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:
    """
    Returns the node diameters of the connected components in the hypergraph.

    Parameters
    ----------
    s : int, optional, default=1
        The number of shared hyperedges required for nodes to be considered adjacent.

    Returns
    -------
    Tuple[int, List[int], List[Set[Any]]]
        - Maximum diameter among all connected components.
        - List of diameters for each s-node connected component.
        - List of sets, each containing node IDs in an s-node connected component.

    Raises
    ------
    HypergraphError
        If the hypergraph is not s-connected or has no nodes.
    """
    A, node_id_map = self.adjacency_matrix(s=s, index=True)
    if A is None or A.shape[0] == 0:
        raise HypergraphError("The hypergraph has no nodes to compute diameters.")

    graph = nx.from_scipy_sparse_array(A)

    if not nx.is_connected(graph):
        raise HypergraphError(f"Hypergraph is not s-node-connected. s={s}")

    diams = []
    comps = []
    for component in nx.connected_components(graph):
        subgraph = graph.subgraph(component)
        if len(subgraph) == 1:
            diamc = 0  # Diameter of a single node is 0
        else:
            try:
                diamc = nx.diameter(subgraph)
            except nx.NetworkXError:
                diamc = float('inf')  # Infinite diameter if the subgraph is not connected
        diams.append(diamc)
        component_nodes = {node_id_map[node] for node in component}
        comps.append(component_nodes)

    if not diams:
        raise HypergraphError("No connected components found to compute diameters.")

    max_diam = max(diams)
    return max_diam, diams, comps
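
A sketch of the per-component variant, continuing the running example:

# Maximum diameter, one diameter per component, and the node sets per component.
max_d, diameters, components = hg.compute_node_diameters(s=1)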
copy(name=None)

Creates a deep copy of the hypergraph instance.

Parameters

name : Optional[str], default=None
    The name for the copied Hypergraph. If not provided, retains the original name.

Returns

Hypergraph
    A new Hypergraph instance that is a deep copy of the original.

Source code in src/aeiva/hypergraph/hypergraph.py
def copy(self, name: Optional[str] = None) -> 'Hypergraph':
    """
    Creates a deep copy of the hypergraph instance.

    Parameters
    ----------
    name : Optional[str], default=None
        The name for the copied Hypergraph. If not provided, retains the original name.

    Returns
    -------
    Hypergraph
        A new Hypergraph instance that is a deep copy of the original.
    """

    # Deep copy hyperedges
    hyperedges_dict = {}
    for he_id, he in self.hyperedges.items():
        hyperedges_dict[he_id] = {
            'nodes': list(he.nodes),
            'properties': copy.deepcopy(he.properties)
        }

    # Deep copy node_properties and hyperedge_properties
    node_properties_copy = copy.deepcopy(self.node_properties)
    hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)

    # Create a new Hypergraph instance with the copied data
    return Hypergraph(
        hyperedges=hyperedges_dict,
        node_properties=node_properties_copy,
        hyperedge_properties=hyperedge_properties_copy,
        name=name if name is not None else self.name
    )
deepcopy(name=None)

Creates a deep copy of the hypergraph.

Parameters

name : Optional[str], default=None
    The name assigned to the cloned hypergraph. If None, defaults to the original hypergraph's name suffixed with '_deepcopy'.

Returns

Hypergraph
    A deep copy of the hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def deepcopy(self, name: Optional[str] = None) -> 'Hypergraph':
    """
    Creates a deep copy of the hypergraph.

    Parameters
    ----------
    name : Optional[str], default=None
        The name assigned to the cloned hypergraph. If None, defaults to the original hypergraph's name suffixed with '_deepcopy'.

    Returns
    -------
    Hypergraph
        A deep copy of the hypergraph.
    """

    # Deep copy hyperedges
    hyperedges_copy = {
        he_id: {
            'nodes': hyperedge.nodes.copy(),
            'properties': copy.deepcopy(hyperedge.properties)
        }
        for he_id, hyperedge in self.hyperedges.items()
    }

    # Deep copy node properties
    node_properties_copy = copy.deepcopy(self.node_properties)

    # Deep copy hyperedge properties
    hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)

    # Set name
    cloned_name = f"{self.name}_deepcopy" if name is None else name

    # Initialize the cloned hypergraph
    cloned_H = Hypergraph(
        hyperedges=hyperedges_copy,
        node_properties=node_properties_copy,
        hyperedge_properties=hyperedge_properties_copy,
        name=cloned_name
    )

    return cloned_H
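
Illustrative usage of both copy helpers (note the '_deepcopy' suffix applied by the code above):

hg_copy = hg.copy(name="demo_copy")  # keeps the original name when none is given
hg_clone = hg.deepcopy()             # name defaults to '<original>_deepcopy'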
difference(other, inplace=False, name=None)

Returns the difference of the current hypergraph with another hypergraph. The difference includes nodes and hyperedges present in the current hypergraph but not in the other.

Parameters

other : Hypergraph
    The hypergraph to subtract.
inplace : bool, optional, default=False
    If True, modifies the current hypergraph by removing elements found in other. Otherwise, returns a new Hypergraph instance.
name : Optional[str], default=None
    The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.

Returns

Hypergraph
    The resulting difference hypergraph.

Raises

TypeError
    If other is not an instance of Hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':
    """
    Returns the difference of the current hypergraph with another hypergraph.
    The difference includes nodes and hyperedges present in the current hypergraph but not in the other.

    Parameters
    ----------
    other : Hypergraph
        The hypergraph to subtract.
    inplace : bool, optional, default=False
        If True, modifies the current hypergraph by removing elements found in `other`.
        Otherwise, returns a new Hypergraph instance.
    name : Optional[str], default=None
        The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.

    Returns
    -------
    Hypergraph
        The resulting difference hypergraph.

    Raises
    ------
    TypeError
        If `other` is not an instance of Hypergraph.
    """
    if not isinstance(other, Hypergraph):
        raise TypeError("The `other` parameter must be an instance of Hypergraph.")

    if inplace:
        # Remove hyperedges present in other
        hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())
        self.remove_hyperedges(hyperedges_to_remove, inplace=True)
        # Remove nodes present in other
        nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())
        self.remove_nodes_from(nodes_to_remove, inplace=True)
        return self
    else:
        # Create a new Hypergraph instance
        new_hyperedges = {he_id: copy.deepcopy(he) for he_id, he in self.hyperedges.items() if he_id not in other.hyperedges}
        new_hyperedge_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items() if he_id not in other.hyperedges}
        new_node_properties = {node_id: copy.deepcopy(props) for node_id, props in self.node_properties.items() if node_id not in other.node_properties}

        # Reconstruct graph
        new_graph = nx.Graph()
        new_bipartite_nodes = set()
        for he_id, hyperedge in new_hyperedges.items():
            new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])
            new_bipartite_nodes.add(he_id)
            for node in hyperedge.nodes:
                if node in new_node_properties:
                    new_graph.add_edge(he_id, node)

        return Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            },
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=name if name else f"Difference_of_{self.name}_{other.name}"
        )
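
A short sketch (other_hg stands for any other Hypergraph instance):

# Keep only hyperedges and nodes that do not appear in other_hg.
diff = hg.difference(other_hg, name="demo_minus_other")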
dual(name=None)

Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.

Parameters

name : Optional[str], default=None
    Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.

Returns

Hypergraph
    A new Hypergraph instance representing the dual of the current hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def dual(self, name: Optional[str] = None) -> "Hypergraph":
    """
    Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.

    Parameters
    ----------
    name : Optional[str], default=None
        Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.

    Returns
    -------
    Hypergraph
        A new Hypergraph instance representing the dual of the current hypergraph.
    """
    # Initialize dual hyperedges, which will correspond to original nodes
    dual_hyperedges = {}

    # Invert the node-hyperedge structure
    for he_id, hyperedge in self.hyperedges.items():
        for node in hyperedge.nodes:
            # Each original node becomes a hyperedge in the dual
            if node not in dual_hyperedges:
                dual_hyperedges[node] = {'nodes': [], 'properties': self.node_properties.get(node, {})}
            # The new hyperedge (original node) connects to the original hyperedge id as a "node"
            dual_hyperedges[node]['nodes'].append(he_id)

    # Define node properties in the dual as the original hyperedge properties
    dual_node_properties = {he_id: self.hyperedge_properties.get(he_id, {}) for he_id in self.hyperedges}

    # Create and return the dual Hypergraph
    return Hypergraph(
        hyperedges=dual_hyperedges,
        node_properties=dual_node_properties,
        hyperedge_properties=self.node_properties,  # Properties of original nodes now apply to dual hyperedges
        name=name or (self.name + "_dual" if self.name else "dual")
    )
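
Illustrative usage:

# In the dual, original node IDs become hyperedge IDs and vice versa.
dual_hg = hg.dual()
print(dual_hg.edges())  # the original node identifiers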
edge_elements()

Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.

Returns

Dict[Any, List[Any]]
    Dictionary mapping hyperedge IDs to lists of node IDs they contain.

Source code in src/aeiva/hypergraph/hypergraph.py
def edge_elements(self) -> Dict[Any, List[Any]]:
    """
    Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.

    Returns
    -------
    Dict[Any, List[Any]]
        Dictionary mapping hyperedge IDs to lists of node IDs they contain.
    """
    return {he_id: hyperedge.nodes for he_id, hyperedge in self.hyperedges.items()}
edges()

Returns a list of all hyperedge identifiers in the hypergraph.

Returns

List[Any]
    List of hyperedge IDs.

Source code in src/aeiva/hypergraph/hypergraph.py
def edges(self) -> List[Any]:
    """
    Returns a list of all hyperedge identifiers in the hypergraph.

    Returns
    -------
    List[Any]
        List of hyperedge IDs.
    """
    return list(self.hyperedges.keys())
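
A sketch combining both accessors, continuing the running example:

for he_id in hg.edges():                     # every hyperedge ID
    print(he_id, hg.edge_elements()[he_id])  # its member nodes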
from_bipartite_graph(bipartite_graph, hyperedge_prefix='HE', node_prefix='N', name=None) classmethod

Constructs a Hypergraph instance from a bipartite graph.

Parameters

bipartite_graph : nx.Graph
    A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes.
hyperedge_prefix : str, optional, default="HE"
    The prefix to identify hyperedge nodes in the bipartite graph.
node_prefix : str, optional, default="N"
    The prefix to identify regular nodes in the bipartite graph.
name : Optional[str], default=None
    The name assigned to the new Hypergraph. If None, defaults to 'FromBipartiteGraph'.

Returns

Hypergraph
    The constructed Hypergraph instance.

Raises

ValueError
    If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.

Source code in src/aeiva/hypergraph/hypergraph.py
@classmethod
def from_bipartite_graph(cls, bipartite_graph: nx.Graph, hyperedge_prefix: str = "HE", node_prefix: str = "N", name: Optional[str] = None) -> 'Hypergraph':
    """
    Constructs a Hypergraph instance from a bipartite graph.

    Parameters
    ----------
    bipartite_graph : nx.Graph
        A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes.
    hyperedge_prefix : str, optional, default="HE"
        The prefix to identify hyperedge nodes in the bipartite graph.
    node_prefix : str, optional, default="N"
        The prefix to identify regular nodes in the bipartite graph.
    name : Optional[str], default=None
        The name assigned to the new Hypergraph. If None, defaults to 'FromBipartiteGraph'.

    Returns
    -------
    Hypergraph
        The constructed Hypergraph instance.

    Raises
    ------
    ValueError
        If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.
    """
    hyperedges = {}
    node_properties = {}
    hyperedge_properties = {}
    name = name if name else "FromBipartiteGraph"

    for node in bipartite_graph.nodes(data=True):
        node_id, attrs = node
        if node_id.startswith(hyperedge_prefix):
            # It's a hyperedge
            hyperedges[node_id] = HyperEdge(id=node_id, nodes=set(), properties=attrs)
            hyperedge_properties[node_id] = copy.deepcopy(attrs)
        elif node_id.startswith(node_prefix):
            # It's a regular node
            node_properties[node_id] = copy.deepcopy(attrs)
        else:
            raise ValueError(f"Node '{node_id}' does not start with either hyperedge_prefix '{hyperedge_prefix}' or node_prefix '{node_prefix}'.")

    # Assign nodes to hyperedges based on edges in bipartite graph
    for he_id in hyperedges:
        connected_nodes = set(bipartite_graph.neighbors(he_id))
        hyperedges[he_id].nodes = connected_nodes

    # Construct hyperedges dict for __init__
    hyperedges_dict = {
        he_id: {
            'nodes': list(he.nodes),
            'properties': he.properties.copy()
        } for he_id, he in hyperedges.items()
    }

    return cls(
        hyperedges=hyperedges_dict,
        node_properties=node_properties,
        hyperedge_properties=hyperedge_properties,
        name=name
    )
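A short sketch of the round trip from a bipartite graph (the "HE"/"N" prefixes below match the defaults; the import path is an assumption):

import networkx as nx
from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

B = nx.Graph()
# Hyperedge nodes carry the "HE" prefix, regular nodes the "N" prefix
B.add_edges_from([("HE1", "N1"), ("HE1", "N2"), ("HE2", "N2"), ("HE2", "N3")])
hg = Hypergraph.from_bipartite_graph(B)
print(sorted(hg.edges()))  # ['HE1', 'HE2']
print(sorted(hg.nodes()))  # ['N1', 'N2', 'N3']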
get_hyperedge_connected_components(s=1, return_singletons=False)

Yields the s-hyperedge-connected components of the hypergraph.

Parameters

s : int, optional, default=1
    The connectivity level to check.
return_singletons : bool, optional, default=False
    If True, includes singleton components. Otherwise, excludes them.

Yields

Set[Any]
    Sets of hyperedge IDs representing each connected component.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_connected_components(
    self, s: int = 1, return_singletons: bool = False
) -> Iterator[Set[Any]]:
    """
    Yields the s-hyperedge-connected components of the hypergraph.

    Parameters
    ----------
    s : int, optional, default=1
        The connectivity level to check.
    return_singletons : bool, optional, default=False
        If True, includes singleton components. Otherwise, excludes them.

    Yields
    ------
    Set[Any]
        Sets of hyperedge IDs representing each connected component.
    """
    return self.s_connected_components(s=s, hyperedges=True, return_singletons=return_singletons)
get_hyperedge_connected_subgraphs(s=1, return_singletons=False, name=None)

Yields subgraphs corresponding to each s-hyperedge-connected component.

Parameters

s : int, optional, default=1
    The connectivity level to check.
return_singletons : bool, optional, default=False
    If True, includes singleton components. Otherwise, excludes them.
name : Optional[str], default=None
    Base name for the subgraphs. Each subgraph will have a unique name appended.

Yields

Hypergraph
    Subgraphs representing each connected component.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_connected_subgraphs(
    self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None
) -> Iterator['Hypergraph']:
    """
    Yields subgraphs corresponding to each s-hyperedge-connected component.

    Parameters
    ----------
    s : int, optional, default=1
        The connectivity level to check.
    return_singletons : bool, optional, default=False
        If True, includes singleton components. Otherwise, excludes them.
    name : Optional[str], default=None
        Base name for the subgraphs. Each subgraph will have a unique name appended.

    Yields
    ------
    Hypergraph
        Subgraphs representing each connected component.
    """
    return self.s_component_subgraphs(
        s=s,
        hyperedges=True,
        return_singletons=return_singletons,
        name=name
    )
get_hyperedge_distance(source, target, s=1)

Returns the shortest s-walk distance between two hyperedges in the hypergraph.

Parameters

source : Any
    A hyperedge identifier in the hypergraph.
target : Any
    A hyperedge identifier in the hypergraph.
s : int, optional, default=1
    The number of shared nodes required for hyperedges to be considered adjacent.

Returns

Union[int, float]
    The shortest s-walk distance between the source and target hyperedges. Returns float('inf') if no path exists.

Raises

HypergraphError
    If either the source or target hyperedge does not exist in the hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:
    """
    Returns the shortest s-walk distance between two hyperedges in the hypergraph.

    Parameters
    ----------
    source : Any
        A hyperedge identifier in the hypergraph.
    target : Any
        A hyperedge identifier in the hypergraph.
    s : int, optional, default=1
        The number of shared nodes required for hyperedges to be considered adjacent.

    Returns
    -------
    Union[int, float]
        The shortest s-walk distance between the source and target hyperedges.
        Returns `float('inf')` if no path exists.

    Raises
    ------
    HypergraphError
        If either the source or target hyperedge does not exist in the hypergraph.
    """
    if source not in self.hyperedges:
        raise HypergraphError(f"Source hyperedge '{source}' does not exist in the hypergraph.")
    if target not in self.hyperedges:
        raise HypergraphError(f"Target hyperedge '{target}' does not exist in the hypergraph.")

    A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)
    if A is None:
        raise HypergraphError("Hyperedge adjacency matrix could not be generated.")

    graph = nx.from_scipy_sparse_array(A)

    try:
        # The graph built from the sparse matrix is labeled by matrix index,
        # so map the hyperedge IDs through he_id_map before querying.
        distance = nx.shortest_path_length(graph, source=he_id_map[source], target=he_id_map[target])
        return distance
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        warnings.warn(f"No s-walk path between hyperedges '{source}' and '{target}'. Returning infinity.")
        return float('inf')
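For example (a sketch under the same assumptions as the earlier examples: importable Hypergraph, dict-shaped constructor input): three hyperedges chained by single shared nodes are pairwise adjacent at s=1, so the end-to-end distance is 2, while at s=2 no pair is adjacent.

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg = Hypergraph(
    hyperedges={
        "e1": {"nodes": ["a", "b"], "properties": {}},
        "e2": {"nodes": ["b", "c"], "properties": {}},
        "e3": {"nodes": ["c", "d"], "properties": {}},
    },
    node_properties={n: {} for n in "abcd"},
)
print(hg.get_hyperedge_distance("e1", "e3"))       # 2 (via e2)
print(hg.get_hyperedge_distance("e1", "e3", s=2))  # inf: no pair shares 2 nodes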
get_hyperedges_of_node(node_id)

Retrieves all hyperedges that a given node is part of.

Parameters

node_id : Any
    The node identifier.

Returns

Set[Any]
    A set of hyperedge IDs that the node belongs to.

Raises

HypergraphError
    If the node does not exist in the hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedges_of_node(self, node_id: Any) -> Set[Any]:
    """
    Retrieves all hyperedges that a given node is part of.

    Parameters
    ----------
    node_id : Any
        The node identifier.

    Returns
    -------
    Set[Any]
        A set of hyperedge IDs that the node belongs to.

    Raises
    ------
    HypergraphError
        If the node does not exist in the hypergraph.
    """
    if node_id not in self.node_properties:
        raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
    return {he.id for he in self.hyperedges.values() if node_id in he.nodes}
get_node_connected_components(s=1, return_singletons=False)

Yields the s-node-connected components of the hypergraph.

Parameters

s : int, optional, default=1
    The connectivity level to check.
return_singletons : bool, optional, default=False
    If True, includes singleton components. Otherwise, excludes them.

Yields

Set[Any]
    Sets of node IDs representing each connected component.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_connected_components(
    self, s: int = 1, return_singletons: bool = False
) -> Iterator[Set[Any]]:
    """
    Yields the s-node-connected components of the hypergraph.

    Parameters
    ----------
    s : int, optional, default=1
        The connectivity level to check.
    return_singletons : bool, optional, default=False
        If True, includes singleton components. Otherwise, excludes them.

    Yields
    ------
    Set[Any]
        Sets of node IDs representing each connected component.
    """
    return self.s_connected_components(s=s, hyperedges=False, return_singletons=return_singletons)
get_node_connected_subgraphs(s=1, return_singletons=False, name=None)

Yields subgraphs corresponding to each s-node-connected component.

Parameters

s : int, optional, default=1
    The connectivity level to check.
return_singletons : bool, optional, default=False
    If True, includes singleton components. Otherwise, excludes them.
name : Optional[str], default=None
    Base name for the subgraphs. Each subgraph will have a unique name appended.

Yields

Hypergraph
    Subgraphs representing each connected component.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_connected_subgraphs(
    self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None
) -> Iterator['Hypergraph']:
    """
    Yields subgraphs corresponding to each s-node-connected component.

    Parameters
    ----------
    s : int, optional, default=1
        The connectivity level to check.
    return_singletons : bool, optional, default=False
        If True, includes singleton components. Otherwise, excludes them.
    name : Optional[str], default=None
        Base name for the subgraphs. Each subgraph will have a unique name appended.

    Yields
    ------
    Hypergraph
        Subgraphs representing each connected component.
    """
    return self.s_component_subgraphs(
        s=s,
        hyperedges=False,
        return_singletons=return_singletons,
        name=name
    )
get_node_distance(source, target, s=1)

Returns the shortest s-walk distance between two nodes in the hypergraph.

Parameters

source : Any
    A node identifier in the hypergraph.
target : Any
    A node identifier in the hypergraph.
s : int, optional, default=1
    The number of shared hyperedges required for nodes to be considered adjacent.

Returns

Union[int, float]
    The shortest s-walk distance between the source and target nodes. Returns float('inf') if no path exists.

Raises

HypergraphError
    If either the source or target node does not exist in the hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:
    """
    Returns the shortest s-walk distance between two nodes in the hypergraph.

    Parameters
    ----------
    source : Any
        A node identifier in the hypergraph.
    target : Any
        A node identifier in the hypergraph.
    s : int, optional, default=1
        The number of shared hyperedges required for nodes to be considered adjacent.

    Returns
    -------
    Union[int, float]
        The shortest s-walk distance between the source and target nodes.
        Returns `float('inf')` if no path exists.

    Raises
    ------
    HypergraphError
        If either the source or target node does not exist in the hypergraph.
    """
    if source not in self.node_properties:
        raise HypergraphError(f"Source node '{source}' does not exist in the hypergraph.")
    if target not in self.node_properties:
        raise HypergraphError(f"Target node '{target}' does not exist in the hypergraph.")

    A, node_id_map = self.adjacency_matrix(s=s, index=True)
    if A is None:
        raise HypergraphError("Adjacency matrix could not be generated.")

    graph = nx.from_scipy_sparse_array(A)

    try:
        # The graph built from the sparse matrix is labeled by matrix index;
        # map the node IDs through node_id_map (assumed to map node IDs to
        # indices, mirroring hyperedge_adjacency_matrix).
        distance = nx.shortest_path_length(graph, source=node_id_map[source], target=node_id_map[target])
        return distance
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        warnings.warn(f"No s-walk path between '{source}' and '{target}'. Returning infinity.")
        return float('inf')
get_singleton_hyperedges()

Returns a list of singleton hyperedges. A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.

Returns

List[Any]
    A list of singleton hyperedge IDs.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_singleton_hyperedges(self) -> List[Any]:
    """
    Returns a list of singleton hyperedges.
    A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.

    Returns
    -------
    List[Any]
        A list of singleton hyperedge IDs.
    """
    singletons = []
    for he in self.hyperedges.values():
        if len(he.nodes) == 1:
            node = next(iter(he.nodes))
            node_degree = sum(1 for hyperedge in self.hyperedges.values() if node in hyperedge.nodes)
            if node_degree == 1:
                singletons.append(he.id)
    return singletons
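An illustration of the degree condition (same assumptions as the earlier sketches): a size-1 hyperedge only counts as a singleton when its node appears nowhere else.

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg = Hypergraph(
    hyperedges={
        "pair": {"nodes": ["a", "b"], "properties": {}},
        "tiny": {"nodes": ["a"], "properties": {}},
        "solo": {"nodes": ["c"], "properties": {}},
    },
    node_properties={"a": {}, "b": {}, "c": {}},
)
# 'solo' qualifies: size 1 and its node 'c' has degree 1.
# 'tiny' does not: node 'a' also belongs to 'pair'.
print(hg.get_singleton_hyperedges())  # ['solo']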
get_toplexes(return_hypergraph=False)

Computes a maximal collection of toplexes for the hypergraph. A toplex is a hyperedge that is not contained in any other hyperedge.

Parameters

return_hypergraph : bool, optional, default=False
    If True, returns a new Hypergraph consisting only of the toplexes.

Returns

List[Any] or Hypergraph
    A list of toplex hyperedge IDs, or, if return_hypergraph=True, a Hypergraph containing only the toplexes.

Source code in src/aeiva/hypergraph/hypergraph.py
def get_toplexes(self, return_hypergraph: bool = False) -> Union[List[Any], 'Hypergraph']:
    """
    Computes a maximal collection of toplexes for the hypergraph.
    A :term:`toplex` is a hyperedge that is not contained in any other hyperedge.

    Parameters
    ----------
    return_hypergraph : bool, optional, default=False
        If True, returns a new Hypergraph consisting only of the toplexes.

    Returns
    -------
    List[Any] or Hypergraph
        - A list of toplex hyperedge IDs.
        - If `return_hypergraph=True`, returns a Hypergraph containing only the toplexes.
    """
    toplexes = []
    hyperedges = list(self.hyperedges.values())

    for he in hyperedges:
        if not any(he.nodes < other_he.nodes for other_he in hyperedges if he.id != other_he.id):
            toplexes.append(he.id)

    if return_hypergraph:
        return self.restrict_to_specific_hyperedges(toplexes, name="Toplexes")
    return toplexes
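A small sketch of the containment test (same assumptions as the earlier examples):

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg = Hypergraph(
    hyperedges={
        "big":   {"nodes": ["a", "b", "c"], "properties": {}},
        "inner": {"nodes": ["a", "b"], "properties": {}},
        "other": {"nodes": ["c", "d"], "properties": {}},
    },
    node_properties={n: {} for n in "abcd"},
)
# 'inner' is strictly contained in 'big', so it is not a toplex.
print(hg.get_toplexes())  # ['big', 'other']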
hyperedge_adjacency_matrix(s=1, index=False)

Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.

Parameters

s : int, optional, default=1
    The number of shared nodes required for hyperedges to be considered adjacent.
index : bool, optional, default=False
    If True, returns a mapping from hyperedge IDs to matrix indices.

Returns

Tuple[Optional[csr_matrix], Dict[int, Any]]
    - The adjacency matrix in CSR format.
    - A dictionary mapping hyperedge IDs to matrix indices (when index=True; otherwise an empty dict).

Source code in src/aeiva/hypergraph/hypergraph.py
def hyperedge_adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[int, Any]]:
    """
    Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.

    Parameters
    ----------
    s : int, optional, default=1
        The number of shared nodes required for hyperedges to be considered adjacent.
    index : bool, optional, default=False
        If True, returns a mapping from hyperedge IDs to matrix indices.

    Returns
    -------
    Tuple[Optional[csr_matrix], Dict[int, Any]]
        - The adjacency matrix in CSR format.
        - A dictionary mapping hyperedge IDs to matrix indices (when index=True; otherwise an empty dict).
    """
    from scipy.sparse import lil_matrix

    hyperedge_ids = list(self.hyperedges.keys())
    he_index = {he_id: idx for idx, he_id in enumerate(hyperedge_ids)}
    size = len(hyperedge_ids)
    if size == 0:
        return None, {}

    A = lil_matrix((size, size), dtype=int)
    for i, he1 in enumerate(hyperedge_ids):
        nodes1 = self.hyperedges[he1].nodes
        for j in range(i + 1, size):
            he2 = hyperedge_ids[j]
            nodes2 = self.hyperedges[he2].nodes
            shared_nodes = nodes1 & nodes2
            if len(shared_nodes) >= s:
                A[i, j] = 1
                A[j, i] = 1

    A = A.tocsr()

    if index:
        return A, he_index
    return A, {}
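A sketch of an s=2 adjacency query (same assumptions as the earlier examples; note that the returned mapping goes from hyperedge ID to matrix index):

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg = Hypergraph(
    hyperedges={
        "e1": {"nodes": ["a", "b", "c"], "properties": {}},
        "e2": {"nodes": ["b", "c", "d"], "properties": {}},
        "e3": {"nodes": ["d", "e"], "properties": {}},
    },
    node_properties={n: {} for n in "abcde"},
)
A, he_index = hg.hyperedge_adjacency_matrix(s=2, index=True)
print(A[he_index["e1"], he_index["e2"]])  # 1: e1 and e2 share two nodes (b, c)
print(A[he_index["e2"], he_index["e3"]])  # 0: only one shared node at s=2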
intersection(other, inplace=False, name=None)

Returns the intersection of the current hypergraph with another hypergraph. The intersection includes only nodes and hyperedges present in both hypergraphs.

Parameters

other : Hypergraph
    The hypergraph to intersect with.
inplace : bool, optional, default=False
    If True, modifies the current hypergraph to keep only the intersecting elements. Otherwise, returns a new Hypergraph instance.
name : Optional[str], default=None
    The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.

Returns

Hypergraph
    The resulting intersection hypergraph.

Raises

TypeError
    If other is not an instance of Hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def intersection(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':
    """
    Returns the intersection of the current hypergraph with another hypergraph.
    The intersection includes only nodes and hyperedges present in both hypergraphs.

    Parameters
    ----------
    other : Hypergraph
        The hypergraph to intersect with.
    inplace : bool, optional, default=False
        If True, modifies the current hypergraph to keep only the intersecting elements.
        Otherwise, returns a new Hypergraph instance.
    name : Optional[str], default=None
        The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.

    Returns
    -------
    Hypergraph
        The resulting intersection hypergraph.

    Raises
    ------
    TypeError
        If `other` is not an instance of Hypergraph.
    """
    if not isinstance(other, Hypergraph):
        raise TypeError("The `other` parameter must be an instance of Hypergraph.")

    intersect_nodes = set(self.node_properties.keys()) & set(other.node_properties.keys())
    intersect_hyperedges = set(self.hyperedges.keys()) & set(other.hyperedges.keys())

    if inplace:
        # Remove non-intersecting nodes and hyperedges
        nodes_to_remove = set(self.node_properties.keys()) - intersect_nodes
        hyperedges_to_remove = set(self.hyperedges.keys()) - intersect_hyperedges
        self.remove_nodes_from(nodes_to_remove, inplace=True)
        self.remove_hyperedges(hyperedges_to_remove, inplace=True)
        return self
    else:
        # Create a new Hypergraph instance
        new_hyperedges = {}
        new_node_properties = {node_id: copy.deepcopy(self.node_properties[node_id]) for node_id in intersect_nodes}
        new_hyperedge_properties = {}
        new_graph = nx.Graph()
        new_bipartite_nodes = set()

        for he_id in intersect_hyperedges:
            he_self = self.hyperedges[he_id]
            he_other = other.hyperedges[he_id]
            # Intersection hyperedges have the same nodes and merged properties
            new_nodes = set(he_self.nodes) & set(he_other.nodes)
            if not new_nodes:
                continue  # Skip hyperedges with no common nodes
            new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties={})
            # Merge properties (could define specific rules)
            new_hyperedge_properties[he_id] = {**self.hyperedge_properties.get(he_id, {}), 
                                               **other.hyperedge_properties.get(he_id, {})}
            new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])
            new_bipartite_nodes.add(he_id)
            for node in new_nodes:
                new_graph.add_edge(he_id, node)

        return Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            },
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=name if name else f"Intersection_of_{self.name}_{other.name}"
        )
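A sketch of intersecting two overlapping hypergraphs (same assumptions as the earlier examples):

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg_a = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}},
                "e2": {"nodes": ["b", "c"], "properties": {}}},
    node_properties={"a": {}, "b": {}, "c": {}},
    name="A",
)
hg_b = Hypergraph(
    hyperedges={"e2": {"nodes": ["b", "c"], "properties": {}},
                "e3": {"nodes": ["c", "d"], "properties": {}}},
    node_properties={"b": {}, "c": {}, "d": {}},
    name="B",
)
common = hg_a.intersection(hg_b)
print(common.edges())  # ['e2']: the only hyperedge ID present in both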
is_hyperedge_connected(s=1)

Determines if the hypergraph is s-hyperedge-connected.

Parameters

s : int, optional, default=1
    The connectivity level to check.

Returns

bool
    True if the hypergraph is s-hyperedge-connected, False otherwise.

Source code in src/aeiva/hypergraph/hypergraph.py
def is_hyperedge_connected(self, s: int = 1) -> bool:
    """
    Determines if the hypergraph is s-hyperedge-connected.

    Parameters
    ----------
    s : int, optional, default=1
        The connectivity level to check.

    Returns
    -------
    bool
        True if the hypergraph is s-hyperedge-connected, False otherwise.
    """
    return self._is_connected(s=s, hyperedges=True)
is_node_connected(s=1)

Determines if the hypergraph is s-node-connected.

Parameters

s : int, optional, default=1
    The connectivity level to check.

Returns

bool
    True if the hypergraph is s-node-connected, False otherwise.

Source code in src/aeiva/hypergraph/hypergraph.py
def is_node_connected(self, s: int = 1) -> bool:
    """
    Determines if the hypergraph is s-node-connected.

    Parameters
    ----------
    s : int, optional, default=1
        The connectivity level to check.

    Returns
    -------
    bool
        True if the hypergraph is s-node-connected, False otherwise.
    """
    return self._is_connected(s=s, hyperedges=False)
node_memberships()

Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.

Returns

Dict[Any, List[Any]]
    Dictionary mapping node IDs to the hyperedge IDs they belong to.

Source code in src/aeiva/hypergraph/hypergraph.py
def node_memberships(self) -> Dict[Any, List[Any]]:
    """
    Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.

    Returns
    -------
    Dict[Any, List[Any]]
        Dictionary mapping node IDs to the hyperedge IDs they belong to.
    """
    memberships = {}
    for he_id, hyperedge in self.hyperedges.items():
        for node in hyperedge.nodes:
            memberships.setdefault(node, []).append(he_id)
    return memberships
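This is the node-side counterpart of edge_elements(). A short sketch (same assumptions as the earlier examples):

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}},
                "e2": {"nodes": ["b", "c"], "properties": {}}},
    node_properties={"a": {}, "b": {}, "c": {}},
)
print(hg.node_memberships())  # {'a': ['e1'], 'b': ['e1', 'e2'], 'c': ['e2']}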
nodes()

Returns a list of all unique node identifiers in the hypergraph.

Returns

List[Any]
    List of node IDs.

Source code in src/aeiva/hypergraph/hypergraph.py
def nodes(self) -> List[Any]:
    """
    Returns a list of all unique node identifiers in the hypergraph.

    Returns
    -------
    List[Any]
        List of node IDs.
    """
    return list(self.node_properties.keys())
remove_hyperedge(he_id)

Removes a hyperedge from the hypergraph.

Parameters

he_id : Any
    Identifier of the hyperedge to remove.

Raises

HypergraphError
    If the hyperedge does not exist.

Source code in src/aeiva/hypergraph/hypergraph.py
def remove_hyperedge(self, he_id: Any) -> None:
    """
    Removes a hyperedge from the hypergraph.

    Parameters
    ----------
    he_id : Any
        Identifier of the hyperedge to remove.

    Raises
    ------
    HypergraphError
        If the hyperedge does not exist.
    """
    if he_id not in self.hyperedges:
        raise HypergraphError(f"Hyperedge '{he_id}' does not exist.")

    # Remove hyperedge from the graph, which also removes all incidences
    self.graph.remove_node(he_id)
    self.bipartite_nodes.discard(he_id)

    # Remove from internal structures
    del self.hyperedges[he_id]
    self.hyperedge_properties.pop(he_id, None)
remove_hyperedges(he_ids, inplace=True)

Removes the specified hyperedges from the hypergraph.

Parameters

he_ids : Any | Iterable[Any]
    Hyperedge identifier(s) to remove.
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If any hyperedge ID does not exist.

Source code in src/aeiva/hypergraph/hypergraph.py
def remove_hyperedges(self, he_ids: Union[Any, Iterable[Any]], inplace: bool = True) -> 'Hypergraph':
    """
    Removes the specified hyperedges from the hypergraph.

    Parameters
    ----------
    he_ids : Any | Iterable[Any]
        Hyperedge identifier(s) to remove.
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If any hyperedge ID does not exist.
    """
    if isinstance(he_ids, (str, int)):
        he_ids = [he_ids]
    else:
        he_ids = list(he_ids)

    non_existing = set(he_ids) - set(self.hyperedges.keys())
    if non_existing:
        raise HypergraphError(f"Hyperedges {non_existing} do not exist in the hypergraph.")

    if inplace:
        for he_id in he_ids:
            self.remove_hyperedge(he_id)
        return self
    else:
        # Create a new Hypergraph instance with hyperedges removed
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        for he_id in he_ids:
            del new_hyperedges[he_id]
            new_hyperedge_properties.pop(he_id, None)
            new_graph.remove_node(he_id)
            new_bipartite_nodes.discard(he_id)

        # Remove nodes not connected to any hyperedges
        retained_nodes = set()
        for hyperedge in new_hyperedges.values():
            retained_nodes.update(hyperedge.nodes)

        new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

        # Reconstruct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        }

        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
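A sketch of the non-destructive path (same assumptions as the earlier examples); note that orphaned nodes are dropped from the copy:

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}},
                "e2": {"nodes": ["b", "c"], "properties": {}}},
    node_properties={"a": {}, "b": {}, "c": {}},
)
smaller = hg.remove_hyperedges("e2", inplace=False)
print(smaller.edges())         # ['e1']
print(sorted(smaller.nodes())) # ['a', 'b']: 'c' is dropped as orphaned
print(hg.edges())              # ['e1', 'e2']: the original is untouched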
remove_incidence(he_id, node_id, inplace=True)

Removes a single incidence from the hypergraph.

Parameters

he_id : Any
    Identifier of the hyperedge.
node_id : Any
    Identifier of the node.
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidence removed.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If the hyperedge or node does not exist, or if the incidence does not exist.

Source code in src/aeiva/hypergraph/hypergraph.py
def remove_incidence(
    self,
    he_id: Any,
    node_id: Any,
    inplace: bool = True
) -> 'Hypergraph':
    """
    Removes a single incidence from the hypergraph.

    Parameters
    ----------
    he_id : Any
        Identifier of the hyperedge.
    node_id : Any
        Identifier of the node.
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidence removed.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If the hyperedge or node does not exist, or if the incidence does not exist.
    """
    if he_id not in self.hyperedges:
        raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
    if node_id not in self.node_properties:
        raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
    if node_id not in self.hyperedges[he_id].nodes:
        raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.")

    if inplace:
        # Remove node from HyperEdge's nodes
        self.hyperedges[he_id].remove_node(node_id)
        # Remove edge from graph
        self.graph.remove_edge(he_id, node_id)
        return self
    else:
        # Create a new Hypergraph instance with the incidence removed
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        # Remove node from HyperEdge's nodes
        new_hyperedges[he_id].remove_node(node_id)
        # Remove edge from graph
        new_graph.remove_edge(he_id, node_id)

        # Reconstruct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        }

        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
remove_incidences(incidences, inplace=True)

Removes the specified incidences from the hypergraph.

Parameters

incidences : Iterable[Tuple[Any, Any]]
    Incidence identifiers as tuples of (he_id, node_id).
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidences removed.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If any incidence does not exist.

Source code in src/aeiva/hypergraph/hypergraph.py
def remove_incidences(
    self,
    incidences: Iterable[Tuple[Any, Any]],
    inplace: bool = True
) -> 'Hypergraph':
    """
    Removes the specified incidences from the hypergraph.

    Parameters
    ----------
    incidences : Iterable[Tuple[Any, Any]]
        Incidence identifiers as tuples of (he_id, node_id).
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidences removed.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If any incidence does not exist.
    """
    incidence_ids = list(incidences)

    # Check existence of incidences
    for he_id, node_id in incidence_ids:
        if he_id not in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
        if node_id not in self.hyperedges[he_id].nodes:
            raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.")

    if inplace:
        for he_id, node_id in incidence_ids:
            # Remove node from HyperEdge's nodes
            self.hyperedges[he_id].remove_node(node_id)
            # Remove edge from graph
            self.graph.remove_edge(he_id, node_id)
        return self
    else:
        # Create a new Hypergraph instance with the incidences removed
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        for he_id, node_id in incidence_ids:
            # Remove node from HyperEdge's nodes
            new_hyperedges[he_id].remove_node(node_id)
            # Remove edge from graph
            new_graph.remove_edge(he_id, node_id)

        # Reconstruct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        }

        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
remove_node(node_id, inplace=True)

Removes a node from the hypergraph.

Parameters

node_id : Any
    Identifier of the node to remove.
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the node removed.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If the node does not exist.

Source code in src/aeiva/hypergraph/hypergraph.py
def remove_node(self, node_id: Any, inplace: bool = True) -> 'Hypergraph':
    """
    Removes a node from the hypergraph.

    Parameters
    ----------
    node_id : Any
        Identifier of the node to remove.
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the node removed.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If the node does not exist.
    """
    if node_id not in self.node_properties:
        raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")

    if inplace:
        # Remove node from node_properties
        del self.node_properties[node_id]
        # Remove node from all hyperedges
        for hyperedge in self.hyperedges.values():
            if node_id in hyperedge.nodes:
                hyperedge.remove_node(node_id)
        # Remove node from graph, which also removes all incidences
        self.graph.remove_node(node_id)
        return self
    else:
        # Create a new Hypergraph instance with the node removed
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        # Remove node from node_properties
        del new_node_properties[node_id]
        # Remove node from all hyperedges
        for hyperedge in new_hyperedges.values():
            if node_id in hyperedge.nodes:
                hyperedge.remove_node(node_id)
        # Remove node from graph, which also removes all incidences
        new_graph.remove_node(node_id)

        # Remove nodes not connected to any hyperedges
        retained_nodes = set()
        for hyperedge in new_hyperedges.values():
            retained_nodes.update(hyperedge.nodes)

        new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

        # Reconstruct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        }

        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
remove_nodes_from(nodes, inplace=True)

Removes the specified nodes from the hypergraph.

Parameters

nodes : Any | Iterable[Any]
    Node identifier(s) to remove.
inplace : bool, default=True
    If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.

Returns

Hypergraph
    The updated or new Hypergraph instance.

Raises

HypergraphError
    If any node ID does not exist.

Source code in src/aeiva/hypergraph/hypergraph.py
def remove_nodes_from(
    self,
    nodes: Union[Any, Iterable[Any]],
    inplace: bool = True
) -> 'Hypergraph':
    """
    Removes the specified nodes from the hypergraph.

    Parameters
    ----------
    nodes : Any | Iterable[Any]
        Node identifier(s) to remove.
    inplace : bool, default=True
        If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.

    Returns
    -------
    Hypergraph
        The updated or new Hypergraph instance.

    Raises
    ------
    HypergraphError
        If any node ID does not exist.
    """
    if isinstance(nodes, (str, int)):
        nodes = [nodes]
    else:
        nodes = list(nodes)

    non_existing = set(nodes) - set(self.node_properties.keys())
    if non_existing:
        raise HypergraphError(f"Nodes {non_existing} do not exist in the hypergraph.")

    if inplace:
        for node_id in nodes:
            self.remove_node(node_id)
        return self
    else:
        # Create a new Hypergraph instance with nodes removed
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

        for node_id in nodes:
            del new_node_properties[node_id]
            # Remove node from all hyperedges
            for hyperedge in new_hyperedges.values():
                if node_id in hyperedge.nodes:
                    hyperedge.remove_node(node_id)
            # Remove node from graph, which also removes all incidences
            new_graph.remove_node(node_id)

        # Remove nodes not connected to any hyperedges
        retained_nodes = set()
        for hyperedge in new_hyperedges.values():
            retained_nodes.update(hyperedge.nodes)

        new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

        # Reconstruct hyperedges dict for __init__
        hyperedges_dict = {
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in new_hyperedges.items()
        }

        return Hypergraph(
            hyperedges=hyperedges_dict,
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=self.name
        )
remove_singleton_hyperedges(name=None)

Constructs a clone of the hypergraph with singleton hyperedges removed.

Source code in src/aeiva/hypergraph/hypergraph.py
def remove_singleton_hyperedges(self, name: Optional[str] = None) -> 'Hypergraph':
    """
    Constructs a clone of the hypergraph with singleton hyperedges removed.
    """
    singletons = self.get_singleton_hyperedges()
    if not singletons:
        return self.copy(name=name)

    new_hypergraph = self.remove_hyperedges(singletons, inplace=False)
    new_hypergraph.name = name if name else f"{self.name}_no_singleton_hyperedges"
    return new_hypergraph
restrict_to_specific_hyperedges(hyperedges_to_retain, name=None)

Creates a new hypergraph by retaining only the specified hyperedges and removing all others.

Parameters

hyperedges_to_retain : Iterable[Any]
    An iterable of hyperedge identifiers to retain in the new hypergraph.
name : Optional[str], default=None
    The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_hyperedges'.

Returns

Hypergraph
    A new hypergraph containing only the specified hyperedges and their associated nodes.

Raises

HypergraphError
    If none of the specified hyperedges exist in the hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def restrict_to_specific_hyperedges(
    self,
    hyperedges_to_retain: Iterable[Any],
    name: Optional[str] = None
) -> 'Hypergraph':
    """
    Creates a new hypergraph by retaining only the specified hyperedges and removing all others.

    Parameters
    ----------
    hyperedges_to_retain : Iterable[Any]
        An iterable of hyperedge identifiers to retain in the new hypergraph.

    name : Optional[str], default=None
        The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_hyperedges'.

    Returns
    -------
    Hypergraph
        A new hypergraph containing only the specified hyperedges and their associated nodes.

    Raises
    ------
    HypergraphError
        If none of the specified hyperedges exist in the hypergraph.
    """
    hyperedges_to_retain = set(hyperedges_to_retain)
    existing_hyperedges = set(self.hyperedges.keys())
    invalid_hyperedges = hyperedges_to_retain - existing_hyperedges
    if invalid_hyperedges:
        raise HypergraphError(f"The following hyperedges do not exist and cannot be retained: {invalid_hyperedges}")

    # Determine hyperedges to remove
    hyperedges_to_remove = existing_hyperedges - hyperedges_to_retain
    if not hyperedges_to_remove:
        # No hyperedges to remove; return the original hypergraph
        return self

    # Remove hyperedges using the existing remove_hyperedges method
    restricted_hypergraph = self.remove_hyperedges(hyperedges_to_remove, inplace=False)
    restricted_hypergraph.name = name if name else f"{self.name}_restricted_hyperedges"

    return restricted_hypergraph
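A sketch of restricting to a hyperedge subset (same assumptions as the earlier examples):

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}},
                "e2": {"nodes": ["b", "c"], "properties": {}},
                "e3": {"nodes": ["c", "d"], "properties": {}}},
    node_properties={n: {} for n in "abcd"},
)
core = hg.restrict_to_specific_hyperedges(["e1", "e2"], name="core")
print(core.edges())          # ['e1', 'e2']
print(sorted(core.nodes()))  # ['a', 'b', 'c']: 'd' is dropped with 'e3'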
restrict_to_specific_nodes(nodes_to_retain, name=None)

Creates a new hypergraph by retaining only the specified nodes and removing all others.

Parameters

nodes_to_retain : Iterable[Any]
    An iterable of node identifiers to retain in the new hypergraph.
name : Optional[str], default=None
    The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_nodes'.

Returns

Hypergraph
    A new hypergraph containing only the specified nodes and their associated hyperedges.

Raises

HypergraphError
    If none of the specified nodes exist in the hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def restrict_to_specific_nodes(
    self,
    nodes_to_retain: Iterable[Any],
    name: Optional[str] = None
) -> 'Hypergraph':
    """
    Creates a new hypergraph by retaining only the specified nodes and removing all others.

    Parameters
    ----------
    nodes_to_retain : Iterable[Any]
        An iterable of node identifiers to retain in the new hypergraph.

    name : Optional[str], default=None
        The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_nodes'.

    Returns
    -------
    Hypergraph
        A new hypergraph containing only the specified nodes and their associated hyperedges.

    Raises
    ------
    HypergraphError
        If none of the specified nodes exist in the hypergraph.
    """
    nodes_to_retain = set(nodes_to_retain)
    existing_nodes = set(self.node_properties.keys())
    invalid_nodes = nodes_to_retain - existing_nodes
    if invalid_nodes:
        raise HypergraphError(f"The following nodes do not exist and cannot be retained: {invalid_nodes}")

    # Determine nodes to remove
    nodes_to_remove = existing_nodes - nodes_to_retain
    if not nodes_to_remove:
        # No nodes to remove; return the original hypergraph
        return self

    # Remove nodes using the existing remove_nodes_from method
    restricted_hypergraph = self.remove_nodes_from(nodes_to_remove, inplace=False)
    restricted_hypergraph.name = name if name else f"{self.name}_restricted_nodes"

    return restricted_hypergraph
s_component_subgraphs(s=1, hyperedges=True, return_singletons=False, name=None)

Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.

Parameters

s : int, optional, default=1
    The connectivity level to check.
hyperedges : bool, optional, default=True
    If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components.
return_singletons : bool, optional, default=False
    If True, includes singleton components. Otherwise, excludes them.
name : Optional[str], default=None
    Base name for the subgraphs. Each subgraph will have a unique name appended.

Yields

Hypergraph
    Subgraphs representing each connected component.

Source code in src/aeiva/hypergraph/hypergraph.py
def s_component_subgraphs(
    self,
    s: int = 1,
    hyperedges: bool = True,
    return_singletons: bool = False,
    name: Optional[str] = None
) -> Iterator['Hypergraph']:
    """
    Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.

    Parameters
    ----------
    s : int, optional, default=1
        The connectivity level to check.
    hyperedges : bool, optional, default=True
        If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components.
    return_singletons : bool, optional, default=False
        If True, includes singleton components. Otherwise, excludes them.
    name : Optional[str], default=None
        Base name for the subgraphs. Each subgraph will have a unique name appended.

    Yields
    ------
    Hypergraph
        Subgraphs representing each connected component.
    """
    for idx, component in enumerate(
        self.s_connected_components(s=s, hyperedges=hyperedges, return_singletons=return_singletons)
    ):
        if hyperedges:
            yield self.restrict_to_specific_hyperedges(
                hyperedges_to_retain=component, 
                name=f"{name or self.name}_component_{idx}"
            )
        else:
            yield self.restrict_to_specific_nodes(
                nodes_to_retain=component, 
                name=f"{name or self.name}_component_{idx}"
            )
s_connected_components(s=1, hyperedges=True, return_singletons=False)

Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.

Parameters

s : int, optional, default=1
    The connectivity level to check.
hyperedges : bool, optional, default=True
    If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components.
return_singletons : bool, optional, default=False
    If True, includes singleton components. Otherwise, excludes them.

Yields

Set[Any]
    Sets of hyperedge IDs or node IDs representing each connected component.

Source code in src/aeiva/hypergraph/hypergraph.py
def s_connected_components(
    self, 
    s: int = 1, 
    hyperedges: bool = True, 
    return_singletons: bool = False
) -> Iterator[Set[Any]]:
    """
    Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.

    Parameters
    ----------
    s : int, optional, default=1
        The connectivity level to check.
    hyperedges : bool, optional, default=True
        If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components.
    return_singletons : bool, optional, default=False
        If True, includes singleton components. Otherwise, excludes them.

    Yields
    ------
    Set[Any]
        Sets of hyperedge IDs or node IDs representing each connected component.
    """
    if hyperedges:
        # s-hyperedge-connected: hyperedges are connected if they share at least s nodes
        hyperedge_graph = nx.Graph()
        hyperedge_ids = list(self.hyperedges.keys())
        hyperedge_graph.add_nodes_from(hyperedge_ids)

        for i, he1 in enumerate(hyperedge_ids):
            nodes1 = self.hyperedges[he1].nodes
            for he2 in hyperedge_ids[i + 1:]:
                nodes2 = self.hyperedges[he2].nodes
                shared_nodes = nodes1 & nodes2
                if len(shared_nodes) >= s:
                    hyperedge_graph.add_edge(he1, he2)

        components = nx.connected_components(hyperedge_graph)
        for component in components:
            if not return_singletons and len(component) == 1:
                continue
            yield component
    else:
        # s-node-connected: nodes are connected if they share at least s hyperedges
        node_graph = nx.Graph()
        node_ids = list(self.node_properties.keys())
        node_graph.add_nodes_from(node_ids)

        for i, node1 in enumerate(node_ids):
            hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}
            for node2 in node_ids[i + 1:]:
                hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}
                shared_hyperedges = hyperedges1 & hyperedges2
                if len(shared_hyperedges) >= s:
                    node_graph.add_edge(node1, node2)

        components = nx.connected_components(node_graph)
        for component in components:
            if not return_singletons and len(component) == 1:
                continue
            yield component
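A sketch of iterating the components (same assumptions as the earlier examples): e1 and e2 share node 'b', while e3 is disconnected from both.

from aeiva.hypergraph.hypergraph import Hypergraph  # assumed import path

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}},
                "e2": {"nodes": ["b", "c"], "properties": {}},
                "e3": {"nodes": ["x", "y"], "properties": {}}},
    node_properties={n: {} for n in ["a", "b", "c", "x", "y"]},
)
for comp in hg.s_connected_components(s=1, hyperedges=True, return_singletons=True):
    print(comp)  # {'e1', 'e2'}, then {'e3'}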
symmetric_difference(other, inplace=False, name=None)

Returns the symmetric difference of the current hypergraph with another hypergraph. The symmetric difference includes elements present in either hypergraph but not in both.

Parameters

other : Hypergraph
    The hypergraph to symmetric difference with.
inplace : bool, optional, default=False
    If True, modifies the current hypergraph to keep only the symmetric difference elements. Otherwise, returns a new Hypergraph instance.
name : Optional[str], default=None
    The name for the resulting hypergraph. If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.

Returns

Hypergraph
    The resulting symmetric difference hypergraph.

Raises

TypeError
    If other is not an instance of Hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def symmetric_difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':
    """
    Returns the symmetric difference of the current hypergraph with another hypergraph.
    The symmetric difference includes elements present in either hypergraph but not in both.

    Parameters
    ----------
    other : Hypergraph
        The hypergraph to symmetric difference with.
    inplace : bool, optional, default=False
        If True, modifies the current hypergraph to keep only the symmetric difference elements.
        Otherwise, returns a new Hypergraph instance.
    name : Optional[str], default=None
        The name for the resulting hypergraph. If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.

    Returns
    -------
    Hypergraph
        The resulting symmetric difference hypergraph.

    Raises
    ------
    TypeError
        If `other` is not an instance of Hypergraph.
    """
    if not isinstance(other, Hypergraph):
        raise TypeError("The `other` parameter must be an instance of Hypergraph.")

    if inplace:
        # Hyperedges symmetric difference
        hyperedges_to_add = set(other.hyperedges.keys()) - set(self.hyperedges.keys())
        hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())
        self.remove_hyperedges(hyperedges_to_remove, inplace=True)
        for he_id in hyperedges_to_add:
            hyperedge = other.hyperedges[he_id]
            self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)

        # Nodes symmetric difference
        nodes_to_add = set(other.node_properties.keys()) - set(self.node_properties.keys())
        nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())
        self.remove_nodes_from(nodes_to_remove, inplace=True)
        for node_id in nodes_to_add:
            props = other.node_properties[node_id]
            self.add_node(node_id, properties=props, inplace=True)

        if name:
            self.name = name
        return self
    else:
        # Create a new Hypergraph instance
        union_hg = self.union(other)
        intersection_hg = self.intersection(other)
        return union_hg.difference(intersection_hg, name=name if name else f"SymmetricDifference_of_{self.name}_{other.name}")
to_bipartite_graph(keep_data=False, directed=False)

Creates a bipartite NetworkX graph from the hypergraph. The nodes and hyperedges of the hypergraph become nodes in the bipartite graph. For every hyperedge in the hypergraph and each node it connects to, there is an edge in the bipartite graph.

Parameters

keep_data : bool, optional, default = False If True, includes the node and hyperedge properties in the NetworkX graph. directed : bool, optional, default = False If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.

Returns

networkx.Graph or networkx.DiGraph
    The bipartite graph representation of the hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def to_bipartite_graph(self, keep_data=False, directed=False) -> nx.Graph:
    """
    Creates a bipartite NetworkX graph from the hypergraph.
    The nodes and hyperedges of the hypergraph become nodes in the bipartite graph.
    For every hyperedge in the hypergraph and each node it connects to, there
    is an edge in the bipartite graph.

    Parameters
    ----------
    keep_data : bool, optional, default = False
        If True, includes the node and hyperedge properties in the NetworkX graph.
    directed : bool, optional, default = False
        If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.

    Returns
    -------
    networkx.Graph or networkx.DiGraph
        The bipartite graph representation of the hypergraph.
    """
    # Choose graph type based on directed flag
    B = nx.DiGraph() if directed else nx.Graph()

    if not keep_data:
        # Add nodes with bipartite attributes, where 0 indicates hyperedges and 1 indicates regular nodes
        B.add_nodes_from(self.hyperedges.keys(), bipartite=0)  # hyperedges
        B.add_nodes_from(self.node_properties.keys(), bipartite=1)  # nodes

        # Add edges between hyperedges and nodes based on hyperedges data
        for he_id, hyperedge in self.hyperedges.items():
            for node in hyperedge.nodes:
                B.add_edge(he_id, node)
    else:
        # Add nodes with properties if keep_data is True
        for node_id, properties in self.node_properties.items():
            B.add_node(node_id, bipartite=1, **properties)

        for he_id, hyperedge in self.hyperedges.items():
            B.add_node(he_id, bipartite=0, **self.hyperedge_properties.get(he_id, {}))
            for node in hyperedge.nodes:
                # Add edges with optional properties if keep_data is True
                B.add_edge(he_id, node)

    return B
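
Example

A short sketch of the conversion (same assumed constructor and import path as in the earlier example):

from aeiva.hypergraph.hypergraph import Hypergraph

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}}},
    node_properties={"a": {}, "b": {}},
    hyperedge_properties={"e1": {}},
    name="Demo",
)

B = hg.to_bipartite_graph()
# Hyperedges carry bipartite=0, regular nodes bipartite=1.
print([n for n, d in B.nodes(data=True) if d["bipartite"] == 0])  # ['e1']
print(B.number_of_edges())  # 2: 'e1' is linked to 'a' and to 'b'
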
transpose(name=None)

Transposes the hypergraph by swapping the roles of nodes and hyperedges. The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.

Parameters

name : Optional[str], default=None
    The name assigned to the transposed hypergraph. If None, defaults to the original name suffixed with '_transposed'.

Returns

Hypergraph
    The transposed hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def transpose(self, name: Optional[str] = None) -> 'Hypergraph':
    """
    Transposes the hypergraph by swapping the roles of nodes and hyperedges.
    The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.

    Parameters
    ----------
    name : Optional[str], default=None
        The name assigned to the transposed hypergraph. If None, defaults to the original name suffixed with '_transposed'.

    Returns
    -------
    Hypergraph
        The transposed hypergraph.
    """
    transposed_hyperedges = {node_id: HyperEdge(id=node_id, nodes=set(), properties=copy.deepcopy(props))
                             for node_id, props in self.node_properties.items()}
    transposed_node_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items()}

    for he_id, hyperedge in self.hyperedges.items():
        for node in hyperedge.nodes:
            if node in transposed_hyperedges:
                transposed_hyperedges[node].nodes.add(he_id)

    # Construct the transposed hypergraph
    return Hypergraph(
        hyperedges={
            he_id: {
                'nodes': list(he.nodes),
                'properties': he.properties.copy()
            } for he_id, he in transposed_hyperedges.items()
        },
        node_properties=transposed_node_properties,
        hyperedge_properties={he_id: he.properties.copy() for he_id, he in transposed_hyperedges.items()},
        name=name if name else f"{self.name}_transposed"
    )
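
Example

A sketch of the role swap (assumed import path and constructor, as above):

from aeiva.hypergraph.hypergraph import Hypergraph

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}}},
    node_properties={"a": {}, "b": {}},
    hyperedge_properties={"e1": {}},
    name="Demo",
)

t = hg.transpose()
# The original nodes 'a' and 'b' become hyperedge ids, each containing 'e1'.
print(sorted(t.hyperedges.keys()))  # ['a', 'b']
print(t.name)                       # 'Demo_transposed'
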
union(other, inplace=False, name=None)

Returns the union of the current hypergraph with another hypergraph. The union combines all nodes and hyperedges from both hypergraphs.

Parameters

other : Hypergraph
    The hypergraph to union with.
inplace : bool, optional, default=False
    If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance.
name : Optional[str], default=None
    The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.

Returns

Hypergraph
    The resulting union hypergraph.

Raises

TypeError
    If other is not an instance of Hypergraph.

Source code in src/aeiva/hypergraph/hypergraph.py
def union(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':
    """
    Returns the union of the current hypergraph with another hypergraph.
    The union combines all nodes and hyperedges from both hypergraphs.

    Parameters
    ----------
    other : Hypergraph
        The hypergraph to union with.
    inplace : bool, optional, default=False
        If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance.
    name : Optional[str], default=None
        The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.

    Returns
    -------
    Hypergraph
        The resulting union hypergraph.

    Raises
    ------
    TypeError
        If `other` is not an instance of Hypergraph.
    """
    if not isinstance(other, Hypergraph):
        raise TypeError("The `other` parameter must be an instance of Hypergraph.")

    if inplace:
        # Add nodes from other
        for node_id, props in other.node_properties.items():
            if node_id not in self.node_properties:
                self.add_node(node_id, properties=props, inplace=True)
            else:
                # Optionally, merge properties
                self.node_properties[node_id].update(props)
                self.graph.nodes[node_id].update(props)

        # Add hyperedges from other
        for he_id, hyperedge in other.hyperedges.items():
            if he_id not in self.hyperedges:
                self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)
            else:
                # Optionally, merge properties and nodes
                self.hyperedges[he_id].nodes.update(hyperedge.nodes)
                self.hyperedge_properties[he_id].update(hyperedge.properties)
                for node in hyperedge.nodes:
                    if node not in self.graph:
                        self.add_node(node)
                    self.graph.add_edge(he_id, node)
        if name:
            self.name = name
        return self
    else:
        # Create a new Hypergraph instance
        new_hyperedges = copy.deepcopy(self.hyperedges)
        new_node_properties = copy.deepcopy(self.node_properties)
        new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
        new_graph = copy.deepcopy(self.graph)
        new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)
        new_name = name if name else f"Union_of_{self.name}_{other.name}"

        # Add nodes from other
        for node_id, props in other.node_properties.items():
            if node_id not in new_node_properties:
                new_node_properties[node_id] = copy.deepcopy(props)
                new_graph.add_node(node_id, bipartite='node', **props)

        # Add hyperedges from other
        for he_id, hyperedge in other.hyperedges.items():
            if he_id not in new_hyperedges:
                new_hyperedges[he_id] = copy.deepcopy(hyperedge)
                new_hyperedge_properties[he_id] = copy.deepcopy(other.hyperedge_properties[he_id])
                new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])
                new_bipartite_nodes.add(he_id)
                for node in hyperedge.nodes:
                    new_graph.add_edge(he_id, node)
            else:
                # Merge nodes and properties
                new_hyperedges[he_id].nodes.update(hyperedge.nodes)
                new_hyperedge_properties[he_id].update(other.hyperedge_properties[he_id])
                for node in hyperedge.nodes:
                    new_graph.add_edge(he_id, node)

        # Construct the new Hypergraph
        return Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            },
            node_properties=new_node_properties,
            hyperedge_properties=new_hyperedge_properties,
            name=new_name
        )
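
Example

A sketch contrasting the two modes (assumed import path and constructor, as above):

from aeiva.hypergraph.hypergraph import Hypergraph

hg1 = Hypergraph(
    hyperedges={"e1": {"nodes": ["a"], "properties": {}}},
    node_properties={"a": {}},
    hyperedge_properties={"e1": {}},
    name="HG1",
)
hg2 = Hypergraph(
    hyperedges={"e2": {"nodes": ["b"], "properties": {}}},
    node_properties={"b": {}},
    hyperedge_properties={"e2": {}},
    name="HG2",
)

merged = hg1.union(hg2)       # new instance named 'Union_of_HG1_HG2'
hg1.union(hg2, inplace=True)  # hg1 itself now contains hyperedge 'e2' and node 'b'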

visualization

draw_hyper_edge_labels(H, pos, polys, labels={}, edge_labels_on_edge=True, ax=None, **kwargs)

Draws a label on the hyper edge boundary.

Should be passed a Matplotlib PolyCollection representing the hyper-edges; see the return value of draw_hyper_edges.

The label will be drawn on the least curvy part of the polygon, and will be aligned parallel to the orientation of the polygon where it is drawn.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
polys: PolyCollection
    collection of polygons returned by draw_hyper_edges
labels: dict
    mapping of node id to string label
ax: Axis
    matplotlib axis on which the plot is rendered
kwargs: dict
    Keyword arguments are passed through to Matplotlib's annotate function.

Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edge_labels(
    H, pos, polys, labels={}, edge_labels_on_edge=True, ax=None, **kwargs
):
    """
    Draws a label on the hyper edge boundary.

    Should be passed a Matplotlib PolyCollection representing the hyper-edges; see
    the return value of draw_hyper_edges.

    The label will be drawn on the least curvy part of the polygon, and will be
    aligned parallel to the orientation of the polygon where it is drawn.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    polys: PolyCollection
        collection of polygons returned by draw_hyper_edges
    labels: dict
        mapping of node id to string label
    ax: Axis
        matplotlib axis on which the plot is rendered
    kwargs: dict
        Keyword arguments are passed through to Matplotlib's annotate function.

    """
    ax = ax or plt.gca()

    params = transpose_inflated_kwargs(inflate_kwargs(H.edges(), kwargs))

    for edge, path, params in zip(H.edges(), polys.get_paths(), params):
        s = labels.get(edge, edge)

        theta = 0
        xy = None

        if edge_labels_on_edge:
            # calculate the xy location of the annotation
            # this is the midpoint of the pair of adjacent points the most distant
            d = ((path.vertices[:-1] - path.vertices[1:]) ** 2).sum(axis=1)
            i = d.argmax()

            x1, x2 = path.vertices[i : i + 2]
            x, y = x2 - x1
            theta = 360 * np.arctan2(y, x) / (2 * np.pi)
            theta = (theta + 360) % 360

            while theta > 90:
                theta -= 180

            xy = (x1 + x2) / 2
        else:
            xy = pos[edge]

        # the string is a comma separated list of the edge uid
        ax.annotate(s, xy, rotation=theta, ha="center", va="center", **params)

draw_hyper_edges(H, pos, ax=None, node_radius={}, contain_hyper_edges=False, dr=None, **kwargs)

Draws a convex hull around the nodes contained within each edge in H

Parameters

H: hnx.Hypergraph
    the entity to be drawn
pos: dict
    mapping of node and edge positions to R^2
node_radius: dict
    mapping of node to R^1 (radius of each node)
dr: float
    the spacing between concentric rings
ax: Axis
    matplotlib axis on which the plot is rendered
kwargs: dict
    keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor

Returns

PolyCollection
    a Matplotlib PolyCollection that can be further styled

Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edges(
    H, pos, ax=None, node_radius={}, contain_hyper_edges=False, dr=None, **kwargs
):
    """
    Draws a convex hull around the nodes contained within each edge in H

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    pos: dict
        mapping of node and edge positions to R^2
    node_radius: dict
        mapping of node to R^1 (radius of each node)
    dr: float
        the spacing between concentric rings
    ax: Axis
        matplotlib axis on which the plot is rendered
    kwargs: dict
        keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor

    Returns
    -------
    PolyCollection
        a Matplotlib PolyCollection that can be further styled
    """
    points = layout_hyper_edges(
        H, pos, node_radius=node_radius, dr=dr, contain_hyper_edges=contain_hyper_edges
    )

    polys = PolyCollection(points, **inflate_kwargs(H.edges(), kwargs))

    (ax or plt.gca()).add_collection(polys)

    return polys
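
Because the PolyCollection is returned, it can be restyled after the call. A hedged sketch (import paths assumed from the source locations shown in this reference; the hypergraph must supply the edges(), edge_elements(), and node_memberships() accessors these helpers use):

import matplotlib.pyplot as plt
from aeiva.hypergraph.hypergraph import Hypergraph
from aeiva.hypergraph.visualization import draw_hyper_edges, layout_node_link

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b", "c"], "properties": {}},
                "e2": {"nodes": ["c", "d"], "properties": {}}},
    node_properties={n: {} for n in "abcd"},
    hyperedge_properties={"e1": {}, "e2": {}},
    name="Demo",
)

pos = layout_node_link(hg)        # positions for both nodes and hyperedge anchors
polys = draw_hyper_edges(hg, pos)
polys.set_edgecolor("tab:blue")   # standard Matplotlib Collection styling
polys.set_linewidth(2.0)
plt.gca().axis("equal")
plt.show()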

draw_hyper_edges_two_column(H, pos, ax=None, **kwargs)

Renders hyper edges for the two column layout.

Each node-hyper edge membership is rendered as a line connecting the node in the left column to the edge in the right column.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
pos: dict
    mapping of node and edge positions to R^2
ax: Axis
    matplotlib axis on which the plot is rendered
kwargs: dict
    keyword arguments passed to matplotlib.LineCollection

Returns

LineCollection
    the hyper edges

Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edges_two_column(H, pos, ax=None, **kwargs):
    """
    Renders hyper edges for the two column layout.

    Each node-hyper edge membership is rendered as a line connecting the node
    in the left column to the edge in the right column.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    pos: dict
        mapping of node and edge positions to R^2
    ax: Axis
        matplotlib axis on which the plot is rendered
    kwargs: dict
        keyword arguments passed to matplotlib.LineCollection

    Returns
    -------
    LineCollection
        the hyper edges
    """
    ax = ax or plt.gca()

    pairs = [(v, e) for e in H.edges() for v in H.edge_elements()[e]]

    kwargs = {
        k: v if type(v) != dict else [v.get(e) for _, e in pairs]
        for k, v in kwargs.items()
    }

    lines = LineCollection([(pos[u], pos[v]) for u, v in pairs], **kwargs)

    ax.add_collection(lines)

    return lines

draw_hyper_labels(H, pos, node_radius={}, ax=None, labels={}, **kwargs)

Draws text labels for the hypergraph nodes.

The label is drawn to the right of the node. The node radius is needed (see draw_hyper_nodes) so the text can be offset appropriately as the node size changes.

The text label can be customized by passing in a dictionary, labels, mapping a node to its custom label. By default, the label is the string representation of the node.

Keyword arguments are passed through to Matplotlib's annotate function.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
pos: dict
    mapping of node and edge positions to R^2
node_radius: dict
    mapping of node to R^1 (radius of each node)
ax: Axis
    matplotlib axis on which the plot is rendered
labels: dict
    mapping of node to text label
kwargs: dict
    keyword arguments passed to matplotlib.annotate

Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_labels(H, pos, node_radius={}, ax=None, labels={}, **kwargs):
    """
    Draws text labels for the hypergraph nodes.

    The label is drawn to the right of the node. The node radius is needed (see
    draw_hyper_nodes) so the text can be offset appropriately as the node size
    changes.

    The text label can be customized by passing in a dictionary, labels, mapping
    a node to its custom label. By default, the label is the string
    representation of the node.

    Keyword arguments are passed through to Matplotlib's annotate function.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    pos: dict
        mapping of node and edge positions to R^2
    node_radius: dict
        mapping of node to R^1 (radius of each node)
    ax: Axis
        matplotlib axis on which the plot is rendered
    labels: dict
        mapping of node to text label
    kwargs: dict
        keyword arguments passed to matplotlib.annotate

    """
    ax = ax or plt.gca()
    params = transpose_inflated_kwargs(inflate_kwargs(H.nodes(), kwargs))

    for v, v_kwargs in zip(iter(H.nodes()), params):
        xy = np.array([node_radius.get(v, 0), 0]) + pos[v]
        ax.annotate(
            labels.get(v, v),
            xy,
            **{
                k: (
                    d[v]
                    if hasattr(d, "__getitem__") and type(d) not in {str, tuple}
                    else d
                )
                for k, d in kwargs.items()
            }
        )

draw_hyper_labels_two_column(H, pos, labels={}, with_node_labels=True, with_edge_labels=True, ax=None)

Renders hyper labels (nodes and edges) for the two column layout.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
pos: dict
    mapping of node and edge positions to R^2
labels: dict
    custom labels for nodes and edges can be supplied
with_node_labels: bool
    False to disable node labels
with_edge_labels: bool
    False to disable edge labels
ax: Axis
    matplotlib axis on which the plot is rendered
kwargs: dict
    keyword arguments passed to matplotlib.LineCollection

Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_labels_two_column(
    H, pos, labels={}, with_node_labels=True, with_edge_labels=True, ax=None
):
    """
    Renders hyper labels (nodes and edges) for the two column layout.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    pos: dict
        mapping of node and edge positions to R^2
    labels: dict
        custom labels for nodes and edges can be supplied
    with_node_labels: bool
        False to disable node labels
    with_edge_labels: bool
        False to disable edge labels
    ax: Axis
        matplotlib axis on which the plot is rendered
    kwargs: dict
        keyword arguments passed to matplotlib.LineCollection

    """

    ax = ax or plt.gca()

    to_draw = []
    if with_node_labels:
        to_draw.append((list(H.nodes()), "right"))

    if with_edge_labels:
        to_draw.append((list(H.edges()), "left"))

    for points, ha in to_draw:
        for p in points:
            ax.annotate(labels.get(p, p), pos[p], ha=ha, va="center")

draw_hyper_nodes(H, pos, node_radius={}, r0=None, ax=None, **kwargs)

Draws a circle for each node in H.

The position of each node is specified by a dictionary/list-like, pos, where pos[v] is the xy-coordinate for the vertex. The radius of each node can be specified as a dictionary where node_radius[v] is the radius. If a node is missing from this dictionary, or the node_radius is not specified at all, a sensible default radius is chosen based on distances between nodes given by pos.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
pos: dict
    mapping of node and edge positions to R^2
node_radius: dict
    mapping of node to R^1 (radius of each node)
r0: float
    minimum distance that concentric rings start from the node position
ax: Axis
    matplotlib axis on which the plot is rendered
kwargs: dict
    keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor

Returns

PolyCollection
    a Matplotlib PolyCollection that can be further styled

Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_nodes(H, pos, node_radius={}, r0=None, ax=None, **kwargs):
    """
    Draws a circle for each node in H.

    The position of each node is specified by a dictionary/list-like, pos,
    where pos[v] is the xy-coordinate for the vertex. The radius of each node
    can be specified as a dictionary where node_radius[v] is the radius. If a
    node is missing from this dictionary, or the node_radius is not specified at
    all, a sensible default radius is chosen based on distances between nodes
    given by pos.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    pos: dict
        mapping of node and edge positions to R^2
    node_radius: dict
        mapping of node to R^1 (radius of each node)
    r0: float
        minimum distance that concentric rings start from the node position
    ax: Axis
        matplotlib axis on which the plot is rendered
    kwargs: dict
        keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor

    Returns
    -------
    PolyCollection
        a Matplotlib PolyCollection that can be further styled
    """

    ax = ax or plt.gca()

    r0 = r0 or get_default_radius(H, pos)

    points = [node_radius.get(v, r0) * cp + pos[v] for v in H.nodes()]

    kwargs.setdefault("facecolors", "black")

    circles = PolyCollection(points, **inflate_kwargs(H, kwargs))

    ax.add_collection(circles)

    return circles

draw_rubber_band(H, pos=None, with_color=True, with_node_counts=False, with_edge_counts=False, layout=nx.spring_layout, layout_kwargs={}, ax=None, node_radius=None, edges_kwargs={}, nodes_kwargs={}, edge_labels_on_edge=True, edge_labels={}, edge_labels_kwargs={}, node_labels={}, node_labels_kwargs={}, with_edge_labels=True, with_node_labels=True, node_label_alpha=0.35, edge_label_alpha=0.35, with_additional_edges=None, contain_hyper_edges=False, additional_edges_kwargs={}, return_pos=False)

Draw a hypergraph as a Matplotlib figure

By default this will draw a colorful "rubber band" like hypergraph, where convex hulls represent edges and are drawn around the nodes they contain.

This is a convenience function that wraps calls with sensible parameters to the following lower-level drawing functions:

  • draw_hyper_edges,
  • draw_hyper_edge_labels,
  • draw_hyper_labels, and
  • draw_hyper_nodes

The default layout algorithm is nx.spring_layout, but other layouts can be passed in. The Hypergraph is converted to a bipartite graph, and the layout algorithm is passed the bipartite graph.

If you have a pre-determined layout, you can pass in a "pos" dictionary. This is a dictionary mapping from node id's to x-y coordinates. For example:

>>> pos = {
>>> 'A': (0, 0),
>>> 'B': (1, 2),
>>> 'C': (5, -3)
>>> }

will position the nodes {A, B, C} manually at the locations specified. The coordinate system is in Matplotlib "data coordinates", and the drawing will be centered within the figure.

By default, this will draw in a new figure, but the axis to render in can be specified using ax.

This approach works well for small hypergraphs, and does not guarantee a rigorously "correct" drawing. Overlapping of sets in the drawing generally implies that the sets intersect, but sometimes sets overlap if there is no intersection. It is not possible, in general, to draw a "correct" hypergraph this way for an arbitrary hypergraph, in the same way that not all graphs have planar drawings.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
pos: dict
    mapping of node and edge positions to R^2
with_color: bool
    set to False to disable color cycling of edges
with_node_counts: bool
    set to True to replace the label for collapsed nodes with the number of elements
with_edge_counts: bool
    set to True to label collapsed edges with number of elements
layout: function
    layout algorithm to compute
layout_kwargs: dict
    keyword arguments passed to layout function
ax: Axis
    matplotlib axis on which the plot is rendered
edges_kwargs: dict
    keyword arguments passed to matplotlib.collections.PolyCollection for edges
node_radius: None, int, float, or dict
    radius of all nodes, or dictionary of node:value; the default (None) calculates radius based on number of collapsed nodes; reasonable values range between 1 and 3
nodes_kwargs: dict
    keyword arguments passed to matplotlib.collections.PolyCollection for nodes
edge_labels_on_edge: bool
    whether to draw edge labels on the edge (rubber band) or inside
edge_labels_kwargs: dict
    keyword arguments passed to matplotlib.annotate for edge labels
node_labels_kwargs: dict
    keyword arguments passed to matplotlib.annotate for node labels
with_edge_labels: bool
    set to False to make edge labels invisible
with_node_labels: bool
    set to False to make node labels invisible
node_label_alpha: float
    the transparency (alpha) of the box behind text drawn in the figure for node labels
edge_label_alpha: float
    the transparency (alpha) of the box behind text drawn in the figure for edge labels
with_additional_edges: networkx.Graph
    ...
contain_hyper_edges: bool
    whether the rubber band should be drawn around the location of the edge in the bipartite graph. This may be invisible unless "with_additional_edges" contains this information.

Source code in src/aeiva/hypergraph/visualization.py
def draw_rubber_band(
    H,
    pos=None,
    with_color=True,
    with_node_counts=False,
    with_edge_counts=False,
    layout=nx.spring_layout,
    layout_kwargs={},
    ax=None,
    node_radius=None,
    edges_kwargs={},
    nodes_kwargs={},
    edge_labels_on_edge=True,
    edge_labels={},
    edge_labels_kwargs={},
    node_labels={},
    node_labels_kwargs={},
    with_edge_labels=True,
    with_node_labels=True,
    node_label_alpha=0.35,
    edge_label_alpha=0.35,
    with_additional_edges=None,
    contain_hyper_edges=False,
    additional_edges_kwargs={},
    return_pos=False,
):
    """
    Draw a hypergraph as a Matplotlib figure

    By default this will draw a colorful "rubber band" like hypergraph, where
    convex hulls represent edges and are drawn around the nodes they contain.

    This is a convenience function that wraps calls with sensible parameters to
    the following lower-level drawing functions:

    * draw_hyper_edges,
    * draw_hyper_edge_labels,
    * draw_hyper_labels, and
    * draw_hyper_nodes

    The default layout algorithm is nx.spring_layout, but other layouts can be
    passed in. The Hypergraph is converted to a bipartite graph, and the layout
    algorithm is passed the bipartite graph.

    If you have a pre-determined layout, you can pass in a "pos" dictionary.
    This is a dictionary mapping from node id's to x-y coordinates. For example:

        >>> pos = {
        >>> 'A': (0, 0),
        >>> 'B': (1, 2),
        >>> 'C': (5, -3)
        >>> }

    will position the nodes {A, B, C} manually at the locations specified. The
    coordinate system is in Matplotlib "data coordinates", and the drawing will
    be centered within the figure.

    By default, this will draw in a new figure, but the axis to render in can be
    specified using :code:`ax`.

    This approach works well for small hypergraphs, and does not guarantee
    a rigorously "correct" drawing. Overlapping of sets in the drawing generally
    implies that the sets intersect, but sometimes sets overlap if there is no
    intersection. It is not possible, in general, to draw a "correct" hypergraph
    this way for an arbitrary hypergraph, in the same way that not all graphs
    have planar drawings.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    pos: dict
        mapping of node and edge positions to R^2
    with_color: bool
        set to False to disable color cycling of edges
    with_node_counts: bool
        set to True to replace the label for collapsed nodes with the number of elements
    with_edge_counts: bool
        set to True to label collapsed edges with number of elements
    layout: function
        layout algorithm to compute
    layout_kwargs: dict
        keyword arguments passed to layout function
    ax: Axis
        matplotlib axis on which the plot is rendered
    edges_kwargs: dict
        keyword arguments passed to matplotlib.collections.PolyCollection for edges
    node_radius: None, int, float, or dict
        radius of all nodes, or dictionary of node:value; the default (None) calculates radius based on number of collapsed nodes; reasonable values range between 1 and 3
    nodes_kwargs: dict
        keyword arguments passed to matplotlib.collections.PolyCollection for nodes
    edge_labels_on_edge: bool
        whether to draw edge labels on the edge (rubber band) or inside
    edge_labels_kwargs: dict
        keyword arguments passed to matplotlib.annotate for edge labels
    node_labels_kwargs: dict
        keyword arguments passed to matplotlib.annotate for node labels
    with_edge_labels: bool
        set to False to make edge labels invisible
    with_node_labels: bool
        set to False to make node labels invisible
    node_label_alpha: float
        the transparency (alpha) of the box behind text drawn in the figure for node labels
    edge_label_alpha: float
        the transparency (alpha) of the box behind text drawn in the figure for edge labels
    with_additional_edges: networkx.Graph
        ...
    contain_hyper_edges: bool
        whether the rubber band should be drawn around the location of the edge in the bipartite graph. This may be invisible unless "with_additional_edges" contains this information.

    """

    ax = ax or plt.gca()

    if pos is None:
        pos = layout_node_link(H, with_additional_edges, layout=layout, **layout_kwargs)

    r0 = get_default_radius(H, pos)
    a0 = np.pi * r0**2

    def get_node_radius(v):
        if node_radius is None:
            return np.sqrt(a0 * get_collapsed_size(v) / np.pi)
        elif hasattr(node_radius, "get"):
            return node_radius.get(v, 1) * r0
        return node_radius * r0

    # guarantee that node radius is a dictionary mapping nodes to values
    node_radius = {v: get_node_radius(v) for v in H.nodes()}

    # for convenience, we are using setdefault to mutate the argument
    # however, we need to copy this to prevent side-effects
    edges_kwargs = edges_kwargs.copy()
    edges_kwargs.setdefault("edgecolors", plt.cm.tab10(np.arange(len((H.edges()))) % 10))
    edges_kwargs.setdefault("facecolors", "none")

    polys = draw_hyper_edges(
        H,
        pos,
        node_radius=node_radius,
        ax=ax,
        contain_hyper_edges=contain_hyper_edges,
        **edges_kwargs
    )

    if with_additional_edges:
        nx.draw_networkx_edges(
            with_additional_edges,
            pos=pos,
            ax=ax,
            **inflate_kwargs(with_additional_edges.edges(), additional_edges_kwargs)
        )

    if with_edge_labels:
        labels = get_frozenset_label(
            H.edges(), count=with_edge_counts, override=edge_labels
        )

        draw_hyper_edge_labels(
            H,
            pos,
            polys,
            color=edges_kwargs["edgecolors"],
            backgroundcolor=(1, 1, 1, edge_label_alpha),
            labels=labels,
            ax=ax,
            edge_labels_on_edge=edge_labels_on_edge,
            **edge_labels_kwargs
        )

    if with_node_labels:
        labels = get_frozenset_label(
            H.nodes(), count=with_node_counts, override=node_labels
        )

        draw_hyper_labels(
            H,
            pos,
            node_radius=node_radius,
            labels=labels,
            ax=ax,
            va="center",
            xytext=(5, 0),
            textcoords="offset points",
            backgroundcolor=(1, 1, 1, node_label_alpha),
            **node_labels_kwargs
        )

    draw_hyper_nodes(H, pos, node_radius=node_radius, ax=ax, **nodes_kwargs)

    if len(H.nodes()) == 1:
        x, y = pos[list(H.nodes())[0]]
        s = 20

        ax.axis([x - s, x + s, y - s, y + s])
    else:
        ax.axis("equal")

    ax.axis("off")
    if return_pos:
        return pos
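
Example

A minimal end-to-end sketch (import paths and the constructor call are assumptions based on the listings in this reference):

import matplotlib.pyplot as plt
from aeiva.hypergraph.hypergraph import Hypergraph
from aeiva.hypergraph.visualization import draw_rubber_band

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b", "c"], "properties": {}},
                "e2": {"nodes": ["c", "d"], "properties": {}}},
    node_properties={n: {} for n in "abcd"},
    hyperedge_properties={"e1": {}, "e2": {}},
    name="Demo",
)

pos = draw_rubber_band(hg, return_pos=True)  # keep positions to reuse in later draws
plt.show()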

draw_two_column(H, with_node_labels=True, with_edge_labels=True, with_node_counts=False, with_edge_counts=False, with_color=True, edge_kwargs=None, ax=None)

Draw a hypergraph using a two-column layout.

This is intended to reproduce an illustrative technique for bipartite graphs and hypergraphs that is typically used in papers and textbooks.

The left column is reserved for nodes and the right column is reserved for edges. A line is drawn between a node and an edge it belongs to.

The order of nodes and edges is optimized to reduce line crossings between the two columns. Spacing between disconnected components is adjusted to make the diagram easier to read, by reducing the angle of the lines.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
with_node_labels: bool
    False to disable node labels
with_edge_labels: bool
    False to disable edge labels
with_node_counts: bool
    set to True to label collapsed nodes with number of elements
with_edge_counts: bool
    set to True to label collapsed edges with number of elements
with_color: bool
    set to False to disable color cycling of hyper edges
edge_kwargs: dict
    keyword arguments to pass to matplotlib.LineCollection
ax: Axis
    matplotlib axis on which the plot is rendered

Source code in src/aeiva/hypergraph/visualization.py
def draw_two_column(
    H,
    with_node_labels=True,
    with_edge_labels=True,
    with_node_counts=False,
    with_edge_counts=False,
    with_color=True,
    edge_kwargs=None,
    ax=None,
):
    """
    Draw a hypergraph using a two-column layout.

    This is intended to reproduce an illustrative technique for bipartite graphs
    and hypergraphs that is typically used in papers and textbooks.

    The left column is reserved for nodes and the right column is reserved for
    edges. A line is drawn between a node and an edge it belongs to.

    The order of nodes and edges is optimized to reduce line crossings between
    the two columns. Spacing between disconnected components is adjusted to make
    the diagram easier to read, by reducing the angle of the lines.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    with_node_labels: bool
        False to disable node labels
    with_edge_labels: bool
        False to disable edge labels
    with_node_counts: bool
        set to True to label collapsed nodes with number of elements
    with_edge_counts: bool
        set to True to label collapsed edges with number of elements
    with_color: bool
        set to False to disable color cycling of hyper edges
    edge_kwargs: dict
        keyword arguments to pass to matplotlib.LineCollection
    ax: Axis
        matplotlib axis on which the plot is rendered
    """

    edge_kwargs = edge_kwargs or {}

    ax = ax or plt.gca()

    pos = layout_two_column(H)

    V = [v for v in H.nodes()]
    E = [e for e in H.edges()]

    labels = {}
    labels.update(get_frozenset_label(V, count=with_node_counts))
    labels.update(get_frozenset_label(E, count=with_edge_counts))

    if with_color:
        edge_kwargs["color"] = {
            e: plt.cm.tab10(i % 10) for i, e in enumerate(H.edges())
        }

    draw_hyper_edges_two_column(H, pos, ax=ax, **edge_kwargs)
    draw_hyper_labels_two_column(
        H,
        pos,
        labels,
        ax=ax,
        with_node_labels=with_node_labels,
        with_edge_labels=with_edge_labels,
    )
    ax.autoscale_view()

    ax.axis("off")

get_default_radius(H, pos)

Calculate a reasonable default node radius

This function iterates over the hyper edges and finds the most distant pair of points given the positions provided. Then, the node radius is a fraction of the median of this distance taken across all hyper-edges.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
pos: dict
    mapping of node and edge positions to R^2

Returns

float
    the recommended radius

Source code in src/aeiva/hypergraph/visualization.py
def get_default_radius(H, pos):
    """
    Calculate a reasonable default node radius

    This function iterates over the hyper edges and finds the most distant
    pair of points given the positions provided. Then, the node radius is a fraction
    of the median of this distance taken across all hyper-edges.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    pos: dict
        mapping of node and edge positions to R^2

    Returns
    -------
    float
        the recommended radius

    """
    if len(H) > 1:
        return 0.0125 * np.median(
            [pdist(np.vstack(list(map(pos.get, H.nodes())))).max() for nodes in H.edges()]
        )
    return 1

get_frozenset_label(S, count=False, override={})

Helper function for rendering the labels of possibly collapsed nodes and edges

Parameters

S: iterable
    list of entities to be labeled
count: bool
    True if labels should be counts of entities instead of list

Returns

dict
    mapping of entity to its string representation

Source code in src/aeiva/hypergraph/visualization.py
def get_frozenset_label(S, count=False, override={}):
    """
    Helper function for rendering the labels of possibly collapsed nodes and edges

    Parameters
    ----------
    S: iterable
        list of entities to be labeled
    count: bool
        True if labels should be counts of entities instead of list

    Returns
    -------
    dict
        mapping of entity to its string representation
    """

    def helper(v):
        if type(v) == str:
            n = get_collapsed_size(v)
            if count and n > 1:
                return f"x {n}"
            elif count:
                return ""
        return str(v)

    return {v: override.get(v, helper(v)) for v in S}

get_line_graph(H, collapse=True)

Computes the line graph, a directed graph, where a directed edge (u, v) exists if the edge u is a subset of the edge v in the hypergraph.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
collapse: bool
    True if edges should be added if hyper edges are identical

Returns

networkx.DiGraph
    A directed graph

Source code in src/aeiva/hypergraph/visualization.py
def get_line_graph(H, collapse=True):
    """
    Computes the line graph, a directed graph, where a directed edge (u, v)
    exists if the edge u is a subset of the edge v in the hypergraph.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    collapse: bool
        True if edges should be added if hyper edges are identical

    Returns
    -------
    networkx.DiGraph
        A directed graph
    """
    D = nx.DiGraph()

    V = {edge: set(nodes) for edge, nodes in H.edge_elements().items()}

    D.add_nodes_from(V)

    for u, v in combinations(V, 2):
        if V[u] != V[v] or not collapse:
            if V[u].issubset(V[v]):
                D.add_edge(u, v)
            elif V[v].issubset(V[u]):
                D.add_edge(v, u)

    return D
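
Example

A sketch of the subset relation the line graph encodes (assumed import paths and constructor, as in the earlier examples):

from aeiva.hypergraph.hypergraph import Hypergraph
from aeiva.hypergraph.visualization import get_line_graph

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}},
                "e2": {"nodes": ["a", "b", "c"], "properties": {}}},
    node_properties={n: {} for n in "abc"},
    hyperedge_properties={"e1": {}, "e2": {}},
    name="Demo",
)

D = get_line_graph(hg)
print(list(D.edges()))  # [('e1', 'e2')]: e1's node set is a subset of e2's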

get_set_layering(H, collapse=True)

Computes a layering of the edges in the hyper graph.

In this layering, each edge is assigned a level. An edge u will be above (i.e., have a smaller level value) another edge v if v is a subset of u.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
collapse: bool
    True if edges should be added if hyper edges are identical

Returns

dict
    a mapping of vertices in H to integer levels

Source code in src/aeiva/hypergraph/visualization.py
def get_set_layering(H, collapse=True):
    """
    Computes a layering of the edges in the hyper graph.

    In this layering, each edge is assigned a level. An edge u will be above
    (i.e., have a smaller level value) another edge v if v is a subset of u.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    collapse: bool
        True if edges should be added if hyper edges are identical

    Returns
    -------
    dict
        a mapping of vertices in H to integer levels
    """

    D = get_line_graph(H, collapse=collapse)

    levels = {}

    for v in nx.topological_sort(D):
        parent_levels = [levels[u] for u, _ in D.in_edges(v)]
        levels[v] = max(parent_levels) + 1 if len(parent_levels) else 0

    return levels

inflate_kwargs(items, kwargs)

Helper function to expand keyword arguments.

Parameters

n: int
    length of resulting list if argument is expanded
kwargs: dict
    keyword arguments to be expanded

Returns

dict
    dictionary with same keys as kwargs and whose values are lists of length n

Source code in src/aeiva/hypergraph/visualization.py
def inflate_kwargs(items, kwargs):
    """
    Helper function to expand keyword arguments.

    Parameters
    ----------
    n: int
        length of resulting list if argument is expanded
    kwargs: dict
        keyword arguments to be expanded

    Returns
    -------
    dict
        dictionary with same keys as kwargs and whose values are lists of length n
    """

    return {k: inflate(items, v) for k, v in kwargs.items()}

layout_hyper_edges(H, pos, node_radius={}, dr=None, contain_hyper_edges=False)

Draws a convex hull for each edge in H.

Position of the nodes in the graph is specified by the position dictionary, pos. Convex hulls are spaced out such that if one set contains another, the convex hull will surround the contained set. The amount of spacing added between hulls is specified by the parameter, dr.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
pos: dict
    mapping of node and edge positions to R^2
node_radius: dict
    mapping of node to R^1 (radius of each node)
dr: float
    the spacing between concentric rings
ax: Axis
    matplotlib axis on which the plot is rendered

Returns

dict
    A mapping from hyper edge ids to paths (Nx2 numpy matrices)

Source code in src/aeiva/hypergraph/visualization.py
def layout_hyper_edges(H, pos, node_radius={}, dr=None, contain_hyper_edges=False):
    """
    Draws a convex hull for each edge in H.

    Position of the nodes in the graph is specified by the position dictionary,
    pos. Convex hulls are spaced out such that if one set contains another, the
    convex hull will surround the contained set. The amount of spacing added
    between hulls is specified by the parameter, dr.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    pos: dict
        mapping of node and edge positions to R^2
    node_radius: dict
        mapping of node to R^1 (radius of each node)
    dr: float
        the spacing between concentric rings
    ax: Axis
        matplotlib axis on which the plot is rendered

    Returns
    -------
    dict
        A mapping from hyper edge ids to paths (Nx2 numpy matrices)
    """

    if len(node_radius):
        r0 = min(node_radius.values())
    else:
        r0 = get_default_radius(H, pos)

    dr = dr or r0

    levels = get_set_layering(H)

    radii = {
        v: {v: i for i, v in enumerate(sorted(e, key=levels.get))}
        for v, e in H.node_memberships().items()
    }

    def get_padded_hull(uid, edge):
        # make sure the edge contains at least one node
        if len(edge):
            points = [
                cp * (node_radius.get(v, r0) + dr * (2 + radii[v][uid])) + pos[v]
                for v in edge
            ]

            if contain_hyper_edges:
                points.append(cp * r0 + pos[uid])

            points = np.vstack(points)

        # if not, draw an empty edge centered around the location of the edge node (in the bipartite graph)
        else:
            points = 4 * r0 * cp + pos[uid]

        hull = ConvexHull(points)

        return hull.points[hull.vertices]

    return [get_padded_hull(uid, list(H.edge_elements()[uid])) for uid in H.edges()]

layout_node_link(H, G=None, layout=nx.spring_layout, **kwargs)

Helper function to use a NetworkX-like graph layout algorithm on a Hypergraph

The hypergraph is converted to a bipartite graph, allowing the usual graph layout techniques to be applied.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
G: Graph
    an additional set of links to consider during the layout process
layout: function
    the layout algorithm which accepts a NetworkX graph and keyword arguments
kwargs: dict
    Keyword arguments are passed through to the layout algorithm

Returns

dict
    mapping of node and edge positions to R^2

Source code in src/aeiva/hypergraph/visualization.py
def layout_node_link(H, G=None, layout=nx.spring_layout, **kwargs):
    """
    Helper function to use a NetworkX-like graph layout algorithm on a Hypergraph

    The hypergraph is converted to a bipartite graph, allowing the usual graph layout
    techniques to be applied.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    G: Graph
        an additional set of links to consider during the layout process
    layout: function
        the layout algorithm which accepts a NetworkX graph and keyword arguments
    kwargs: dict
        Keyword arguments are passed through to the layout algorithm

    Returns
    -------
    dict
        mapping of node and edge positions to R^2
    """

    B = H.to_bipartite_graph()

    if G is not None:
        B.add_edges_from(G.edges())

    return layout(B, **kwargs)
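
Example

Any layout with the usual NetworkX signature can be substituted; a sketch (assumptions as in the earlier examples):

import networkx as nx
from aeiva.hypergraph.hypergraph import Hypergraph
from aeiva.hypergraph.visualization import layout_node_link

hg = Hypergraph(
    hyperedges={"e1": {"nodes": ["a", "b"], "properties": {}}},
    node_properties={"a": {}, "b": {}},
    hyperedge_properties={"e1": {}},
    name="Demo",
)

pos = layout_node_link(hg, layout=nx.circular_layout)
print(sorted(pos))  # ['a', 'b', 'e1']: hyperedges receive positions as well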

layout_two_column(H, spacing=2)

Two column (bipartite) layout algorithm.

This algorithm first converts the hypergraph into a bipartite graph and then computes connected components. Disconnected components are handled independently and then stacked together.

Within a connected component, the spectral ordering of the bipartite graph provides a quick and dirty ordering that minimizes edge crossings in the diagram.

Parameters

H: hnx.Hypergraph
    the entity to be drawn
spacing: float
    amount of whitespace between disconnected components

Source code in src/aeiva/hypergraph/visualization.py
def layout_two_column(H, spacing=2):
    """
    Two column (bipartite) layout algorithm.

    This algorithm first converts the hypergraph into a bipartite graph and
    then computes connected components. Disconnected components are handled
    independently and then stacked together.

    Within a connected component, the spectral ordering of the bipartite graph
    provides a quick and dirty ordering that minimizes edge crossings in the
    diagram.

    Parameters
    ----------
    H: hnx.Hypergraph
        the entity to be drawn
    spacing: float
        amount of whitespace between disconnected components
    """
    offset = 0
    pos = {}

    def stack(vertices, x, height):
        for i, v in enumerate(vertices):
            pos[v] = (x, i + offset + (height - len(vertices)) / 2)

    G = H.to_bipartite_graph()
    for ci in nx.connected_components(G):
        Gi = G.subgraph(ci)
        key = {v: i for i, v in enumerate(nx.spectral_ordering(Gi))}.get
        ci_vertices, ci_edges = [
            sorted([v for v, d in Gi.nodes(data=True) if d["bipartite"] == j], key=key)
            for j in [0, 1]
        ]

        height = max(len(ci_vertices), len(ci_edges))

        stack(ci_vertices, 0, height)
        stack(ci_edges, 1, height)

        offset += height + spacing

    return pos

llm

llm_client

LLMClient

Language Model interface that supports synchronous, asynchronous, and streaming modes, and optionally, tool usage via function calls.
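
A hedged usage sketch; the LLMGatewayConfig import path and its full field set are assumptions based on the attributes referenced in the listing below:

from aeiva.llm.llm_client import LLMClient
from aeiva.llm.llm_gateway_config import LLMGatewayConfig

config = LLMGatewayConfig(
    llm_api_key="...",         # required: _validate_config raises without it
    llm_logging_level="info",  # upper-cased when the client builds its logger
)
client = LLMClient(config)
reply = client.generate([{"role": "user", "content": "Hello!"}])
print(reply)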

Source code in src/aeiva/llm/llm_client.py
class LLMClient:
    """
    Language Model interface that supports synchronous, asynchronous, and streaming modes,
    and optionally, tool usage via function calls.
    """

    def __init__(self, config: LLMGatewayConfig):
        self.config = config
        self.metrics = LLMUsageMetrics()
        self.logger = get_logger(__name__, level=config.llm_logging_level.upper())
        self._validate_config()

    def _validate_config(self):
        if not self.config.llm_api_key:
            raise ValueError("API key must be provided in the configuration.")

    @retry_sync(
        max_attempts=lambda self: self.config.llm_num_retries,
        backoff_factor=lambda self: self.config.llm_retry_backoff_factor,
        exceptions=(LLMGatewayError,),  # Catching LLMGatewayError
    )
    def generate(
        self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs
    ) -> str:
        try:
            max_iterations = MAX_TOOL_CALL_LOOP  # Prevent infinite loops
            iteration = 0

            while iteration < max_iterations:
                iteration += 1

                # Build parameters
                params = self._build_params(messages=messages, tools=tools, **kwargs)
                response = llm_completion(**params)
                self._update_metrics(response)
                response_message = response.choices[0].message

                tool_calls = response_message.tool_calls

                if tool_calls:
                    # Append assistant's tool call message
                    messages.append({"role": "assistant", "tool_calls": tool_calls})

                    for tool_call in tool_calls:
                        function_name = tool_call.function.name
                        function_args = json.loads(tool_call.function.arguments)
                        tool_call_id = tool_call.id
                        self.logger.info(f"Tool call id: {tool_call_id}")

                        try:
                            function_response = self.call_tool_sync(
                                api_name=function_name, function_name=function_name, params=function_args
                            )
                        except Exception as e:
                            self.logger.error(f"Error executing tool '{function_name}': {e}")
                            function_response = f"Error executing tool '{function_name}': {e}"

                        # Append the function response to messages
                        messages.append(
                            {
                                "tool_call_id": tool_call_id,
                                "role": "tool",
                                "name": function_name,
                                "content": str(function_response),
                            }
                        )
                    # Continue the loop to handle further function calls
                    continue
                else:
                    # Assistant provided a final response
                    messages.append({"role": "assistant", "content": response_message.content})
                    return response_message.content

            # If loop exceeds max iterations
            raise Exception("Maximum iterations reached without a final response.")

        except Exception as e:
            self.logger.error(f"LLM Gateway Error: {e}")
            raise llm_gateway_exception(e)

    @retry_async(
        max_attempts=lambda self: self.config.llm_num_retries,
        backoff_factor=lambda self: self.config.llm_retry_backoff_factor,
        exceptions=(LLMGatewayError,),  # Catching LLMGatewayError
    )
    async def agenerate(
        self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs
    ) -> str:
        try:
            max_iterations = MAX_TOOL_CALL_LOOP  # Prevent infinite loops
            iteration = 0

            while iteration < max_iterations:
                iteration += 1

                # Build parameters
                params = self._build_params(messages=messages, tools=tools, **kwargs)
                response = await llm_acompletion(**params)
                self._update_metrics(response)
                response_message = response.choices[0].message

                tool_calls = response_message.tool_calls

                if tool_calls:
                    # Append assistant's tool call message
                    messages.append({"role": "assistant", "tool_calls": tool_calls})

                    for tool_call in tool_calls:
                        function_name = tool_call.function.name
                        function_args = json.loads(tool_call.function.arguments)
                        tool_call_id = tool_call.id

                        try:
                            function_response = await self.call_tool(
                                api_name=function_name, function_name=function_name, params=function_args
                            )
                        except Exception as e:
                            self.logger.error(f"Error executing tool '{function_name}': {e}")
                            function_response = f"Error executing tool '{function_name}': {e}"

                        # Append the function response to messages
                        messages.append(
                            {
                                "tool_call_id": tool_call_id,
                                "role": "tool",
                                "name": function_name,
                                "content": str(function_response),
                            }
                        )
                    # Continue the loop to handle further function calls
                    continue
                else:
                    # Assistant provided a final response
                    messages.append({"role": "assistant", "content": response_message.content})
                    return response_message.content

            # If loop exceeds max iterations
            raise Exception("Maximum iterations reached without a final response.")

        except Exception as e:
            self.logger.error(f"LLM Asynchronous Generation Error: {e}")
            raise llm_gateway_exception(e)

    async def stream_generate(
        self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs
    ) -> AsyncGenerator[str, None]:
        try:
            max_iterations = MAX_TOOL_CALL_LOOP  # Prevent infinite loops
            iteration = 0

            while iteration < max_iterations:
                iteration += 1

                # Build parameters
                params = self._build_params(messages=messages, tools=tools, **kwargs)
                response_stream = await llm_acompletion(**params)

                # Prepare to collect the assistant's reply
                tool_calls = []  # Accumulator for tool calls
                full_delta_content = ''  # Accumulator for assistant's content

                # Collect streamed responses
                async for response in response_stream:
                    delta = response.choices[0].delta

                    # Collect assistant's content and yield it
                    if getattr(delta, 'content', None):
                        full_delta_content += delta.content
                        yield delta.content

                    # Check for tool calls in the delta
                    if getattr(delta, 'tool_calls', None):
                        tc_chunk_list = delta.tool_calls
                        for tc_chunk in tc_chunk_list:
                            index = tc_chunk.index
                            # Ensure tool_calls list is large enough
                            while len(tool_calls) <= index:
                                tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
                            tc = tool_calls[index]

                            if getattr(tc_chunk, 'id', None):
                                tc["id"] += tc_chunk.id
                            if getattr(tc_chunk.function, 'name', None):
                                tc["function"]["name"] += tc_chunk.function.name
                            if getattr(tc_chunk.function, 'arguments', None):
                                tc["function"]["arguments"] += tc_chunk.function.arguments

                # After initial streaming, check if there are tool calls
                if tool_calls:
                    # Append the assistant's tool_call message to messages
                    messages.append({"role": "assistant", "tool_calls": tool_calls})

                    # Process each tool_call
                    available_functions = [tool["function"]["name"] for tool in tools]
                    for tool_call in tool_calls:
                        function_name = tool_call["function"]["name"]
                        if function_name not in available_functions:
                            # Handle error if function not found
                            yield f"Function {function_name} does not exist."
                            return
                        # Call the function with arguments
                        try:
                            function_args = json.loads(tool_call["function"]["arguments"])
                        except json.JSONDecodeError as e:
                            self.logger.error(f"Error decoding function arguments: {e}")
                            function_args = {}

                        try:
                            function_response = await self.call_tool(
                                api_name=function_name, function_name=function_name, params=function_args
                            )
                        except Exception as e:
                            self.logger.error(f"Error executing tool '{function_name}': {e}")
                            function_response = f"Error executing tool '{function_name}': {e}"

                        # Append the function's response to messages
                        messages.append(
                            {
                                "tool_call_id": tool_call['id'],
                                "role": "tool",
                                "name": function_name,
                                "content": str(function_response),
                            }
                        )
                    # Continue the loop to handle further function calls
                    continue
                else:
                    # No tool calls, streaming is complete
                    messages.append({"role": "assistant", "content": full_delta_content})
                    return  # Exit the loop

            # If loop exceeds max iterations
            yield "Maximum iterations reached without a final response."

        except Exception as e:
            self.logger.error(f"Streaming LLM Gateway Error: {e}")
            yield "An error occurred during streaming."

    def call_tool_via_server(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise
        """Calls the API via FastAPI server."""
        url = f"http://localhost:8000/api/{api_name}/{function_name}"
        self.logger.info(f"Calling {api_name} with params: {params}")
        response = requests.get(url, params=params)
        if response.status_code == 200:
            json_response = response.json()
            if "result" in json_response:
                return str(json_response["result"])
            else:
                return f"Error from API: {json_response.get('error', 'Unknown error')}"
        else:
            return f"HTTP Error {response.status_code}: {response.text}"

    async def call_tool(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise
        """Calls the API via action module."""
        tool = Tool(api_name)
        return await tool.aexecute(params)

    def call_tool_sync(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise
        """Calls the API via action module."""
        tool = Tool(api_name)
        return tool.execute(params)

    def _build_params(
        self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs
    ) -> Dict[str, Any]:
        params = {
            "model": self.config.llm_model_name,
            "messages": messages,
            "api_key": self.config.llm_api_key,
            "temperature": self.config.llm_temperature,
            "top_p": self.config.llm_top_p,
            "max_tokens": self.config.llm_max_output_tokens,
            "timeout": self.config.llm_timeout,
        }
        params.update(self.config.llm_additional_params)
        params.update(kwargs)

        # Check if the model supports function calling
        if tools and supports_function_calling(self.config.llm_model_name):
            params["tools"] = tools
            params["tool_choice"] = "auto"

        return params

    def _update_metrics(self, response: Any, log: bool = False):  # Logging is off by default; pass log=True to report usage.
        usage = getattr(response, "usage", {})
        self.metrics.add_tokens(
            prompt_tokens=getattr(usage, "prompt_tokens", 0),
            completion_tokens=getattr(usage, "completion_tokens", 0),
        )
        self.metrics.add_cost(getattr(usage, "cost", 0.0))
        if log:
            self.logger.info(
                f"Tokens used: {self.metrics.total_tokens}, Cost: ${self.metrics.total_cost:.4f}"
            )

    def __call__(
        self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs
    ) -> Any:
        if self.config.llm_use_async:
            if self.config.llm_stream:
                return self.stream_generate(messages, tools=tools, **kwargs)
            else:
                return self.agenerate(messages, tools=tools, **kwargs)
        else:
            if self.config.llm_stream:
                # Synchronous streaming is not supported by this client; set llm_use_async=True together with llm_stream=True instead.
                raise NotImplementedError("Synchronous streaming is not supported.")
            else:
                return self.generate(messages, tools=tools, **kwargs)
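
For orientation, here is a minimal usage sketch. The client class itself (and its constructor) is defined earlier in src/aeiva/llm/llm_client.py and is not shown in this excerpt, so the name LLMClient and its construction from a config are assumptions for illustration only.

# Hypothetical usage sketch; LLMClient(config) is assumed, not shown in this excerpt.
from aeiva.llm.llm_client import LLMClient
from aeiva.llm.llm_gateway_config import LLMGatewayConfig

config = LLMGatewayConfig(llm_model_name="gpt-4", llm_api_key="YOUR_KEY")
llm = LLMClient(config)

messages = [{"role": "user", "content": "Say hello."}]
reply = llm(messages)  # __call__ dispatches to generate / agenerate / stream_generate
print(reply)
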
call_tool(api_name, function_name, params) async

Calls the API via action module.

Source code in src/aeiva/llm/llm_client.py
async def call_tool(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise
    """Calls the API via action module."""
    tool = Tool(api_name)
    return await tool.aexecute(params)
call_tool_sync(api_name, function_name, params)

Calls the API via action module.

Source code in src/aeiva/llm/llm_client.py
def call_tool_sync(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise
    """Calls the API via action module."""
    tool = Tool(api_name)
    return tool.execute(params)
call_tool_via_server(api_name, function_name, params)

Calls the API via FastAPI server.

Source code in src/aeiva/llm/llm_client.py
def call_tool_via_server(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise
    """Calls the API via FastAPI server."""
    url = f"http://localhost:8000/api/{api_name}/{function_name}"
    self.logger.info(f"Calling {api_name} with params: {params}")
    response = requests.get(url, params=params)
    if response.status_code == 200:
        json_response = response.json()
        if "result" in json_response:
            return str(json_response["result"])
        else:
            return f"Error from API: {json_response.get('error', 'Unknown error')}"
    else:
        return f"HTTP Error {response.status_code}: {response.text}"

llm_gateway_config

LLMGatewayConfig dataclass

Bases: BaseConfig

Configuration for the Language Model (LLM).

Source code in src/aeiva/llm/llm_gateway_config.py
@dataclass
class LLMGatewayConfig(BaseConfig):
    """
    Configuration for the Language Model (LLM).
    """

    llm_model_name: Optional[str] = field(
        default='gpt-4',
        metadata={"help": "The name of the LLM model to use (e.g., 'gpt-4', 'gpt-3.5-turbo')."}
    )
    llm_api_key: Optional[str] = field(
        default=None,
        metadata={"help": "The API key for authentication with the LLM provider."}
    )
    llm_base_url: Optional[str] = field(
        default=None,
        metadata={"help": "The base URL for API requests to the LLM provider."}
    )
    llm_api_version: Optional[str] = field(
        default=None,
        metadata={"help": "The version of the LLM API to use."}
    )
    llm_embedding_model: Optional[str] = field(
        default=None,
        metadata={"help": "The embedding model to use for tasks requiring embeddings."}
    )
    llm_timeout: Optional[int] = field(
        default=30,
        metadata={"help": "The timeout in seconds for API requests."}
    )
    llm_max_input_tokens: Optional[int] = field(
        default=4096,
        metadata={"help": "The maximum number of input tokens allowed in a request."}
    )
    llm_max_output_tokens: Optional[int] = field(
        default=1024,
        metadata={"help": "The maximum number of output tokens generated by the LLM."}
    )
    llm_temperature: Optional[float] = field(
        default=0.7,
        metadata={"help": "Sampling temperature for response variability (range: 0.0 - 1.0)."}
    )
    llm_top_p: Optional[float] = field(
        default=0.9,
        metadata={"help": "Nucleus sampling probability for token selection (range: 0.0 - 1.0)."}
    )
    llm_num_retries: Optional[int] = field(
        default=3,
        metadata={"help": "The number of times to retry failed API requests."}
    )
    llm_retry_backoff_factor: Optional[float] = field(
        default=0.5,
        metadata={"help": "Factor for exponential backoff between retries."}
    )
    llm_retry_on_status: Optional[Tuple[int, ...]] = field(
        default=(429, 500, 502, 503, 504),
        metadata={"help": "HTTP status codes that should trigger a retry."}
    )
    llm_use_async: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use asynchronous API calls."}
    )
    llm_stream: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to enable streaming responses from the LLM."}
    )
    llm_logging_level: Optional[str] = field(
        default='INFO',
        metadata={"help": "Logging level for the LLM module (e.g., 'DEBUG', 'INFO')."}
    )
    llm_additional_params: Optional[Dict[str, Any]] = field(
        default_factory=dict,
        metadata={"help": "Additional parameters to pass to the LLM API."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Load API keys from the configuration file if not provided
        if not self.llm_api_key:
            self.load_api_key()

    def load_api_key(self):
        config_path = os.path.join(os.path.dirname(__file__), '../../../configs/llm_api_keys.yaml')
        try:
            with open(config_path, 'r') as f:
                keys = yaml.safe_load(f)
                self.llm_api_key = keys.get('openai_api_key')
        except FileNotFoundError:
            raise FileNotFoundError('API keys file not found.')
        except Exception as e:
            raise e

    def to_dict(self):
        return {
            key: ('******' if key == 'llm_api_key' and value else value)
            for key, value in self.__dict__.items()
            if not key.startswith('_')
        }
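
A minimal construction sketch; every field name comes from the dataclass above, and the key value is a placeholder:

from aeiva.llm.llm_gateway_config import LLMGatewayConfig

config = LLMGatewayConfig(
    llm_model_name="gpt-3.5-turbo",
    llm_api_key="YOUR_KEY",   # if omitted, __post_init__ falls back to configs/llm_api_keys.yaml
    llm_temperature=0.2,
)
print(config.to_dict())       # llm_api_key is masked as '******'
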

llm_gateway_exceptions

LLMGatewayError

Bases: Exception

Unified exception class for all LLM-related errors.

Source code in src/aeiva/llm/llm_gateway_exceptions.py
class LLMGatewayError(Exception):
    """Unified exception class for all LLM-related errors."""

    def __init__(self, message: str, original_exception: Exception = None):
        super().__init__(message)
        self.original_exception = original_exception

llm_gateway_exception(e)

Converts a litellm exception to a unified LLMGatewayError.

Source code in src/aeiva/llm/llm_gateway_exceptions.py
def llm_gateway_exception(e: Exception) -> LLMGatewayError:
    """Converts a litellm exception to a unified LLMGatewayError."""
    exception_type = type(e)
    mapped_exception = LITELLM_EXCEPTION_MAP.get(exception_type, LLMGatewayError)
    return mapped_exception(str(e), original_exception=e)
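
A sketch of the intended call-site pattern (mirroring how llm_client.py wraps errors above); the TimeoutError here is just a stand-in for a failing litellm call:

from aeiva.llm.llm_gateway_exceptions import LLMGatewayError, llm_gateway_exception

try:
    raise TimeoutError("upstream timed out")  # stand-in for a failing litellm call
except Exception as e:
    wrapped = llm_gateway_exception(e)        # unmapped types fall back to LLMGatewayError
    print(type(wrapped).__name__, wrapped.original_exception)
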

llm_usage_metrics

LLMUsageMetrics

Tracks metrics such as token usage and cost.

Source code in src/aeiva/llm/llm_usage_metrics.py
class LLMUsageMetrics:
    """
    Tracks metrics such as token usage and cost.
    """
    def __init__(self):
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_cost = 0.0

    def add_tokens(self, prompt_tokens: int, completion_tokens: int):
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
        self.total_tokens += prompt_tokens + completion_tokens

    def add_cost(self, cost: float):
        self.total_cost += cost
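
A small worked example of the accumulation:

from aeiva.llm.llm_usage_metrics import LLMUsageMetrics

metrics = LLMUsageMetrics()
metrics.add_tokens(prompt_tokens=120, completion_tokens=80)
metrics.add_tokens(prompt_tokens=40, completion_tokens=10)
metrics.add_cost(0.0042)
print(metrics.total_tokens)  # 250
print(metrics.total_cost)    # 0.0042
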

model

macaw_model

LlamaAttention

Bases: Module

Multi-headed attention from 'Attention Is All You Need' paper

Source code in src/aeiva/model/macaw_model.py
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
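
A shape-level smoke test, as a sketch; it assumes macaw_model.py is importable and that transformers' LlamaConfig is available, as in the module's own imports. The tiny sizes are arbitrary.

import torch
from transformers import LlamaConfig
from aeiva.model.macaw_model import LlamaAttention

config = LlamaConfig(hidden_size=64, num_attention_heads=4, max_position_embeddings=128)
attn = LlamaAttention(config)

x = torch.randn(2, 5, 64)                                  # (batch, seq_len, hidden_size)
position_ids = torch.arange(5).unsqueeze(0).expand(2, -1)  # needed by the rotary embedding
out, weights, past = attn(x, position_ids=position_ids, output_attentions=True)
print(out.shape)      # torch.Size([2, 5, 64])
print(weights.shape)  # torch.Size([2, 4, 5, 5])
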

LlamaDecoderLayer

Bases: Module

Source code in src/aeiva/model/macaw_model.py
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)

Parameters:

- hidden_states (`torch.FloatTensor`, required): input to the layer of shape `(batch, seq_len, embed_dim)`.
- attention_mask (`torch.FloatTensor`, *optional*, default `None`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- output_attentions (`bool`, *optional*, default `False`): whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
- use_cache (`bool`, *optional*, default `False`): if set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*, default `None`): cached past key and value projection states.

Source code in src/aeiva/model/macaw_model.py
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
    """

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs
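
Continuing the same kind of sketch for a single decoder layer (tiny, arbitrary sizes; assumes the same imports as above):

import torch
from transformers import LlamaConfig
from aeiva.model.macaw_model import LlamaDecoderLayer

config = LlamaConfig(hidden_size=64, num_attention_heads=4,
                     intermediate_size=128, max_position_embeddings=128)
layer = LlamaDecoderLayer(config)

x = torch.randn(1, 6, 64)
position_ids = torch.arange(6).unsqueeze(0)
(hidden,) = layer(x, position_ids=position_ids)  # one-tuple when no attentions/cache requested
print(hidden.shape)  # torch.Size([1, 6, 64])
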

LlamaModel

Bases: LlamaPreTrainedModel

Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [LlamaDecoderLayer]

Parameters:

- config (`LlamaConfig`, required): LlamaConfig

Source code in src/aeiva/model/macaw_model.py
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
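
An end-to-end sketch with a deliberately tiny configuration (all sizes arbitrary; assumes the module is importable):

import torch
from transformers import LlamaConfig
from aeiva.model.macaw_model import LlamaModel

config = LlamaConfig(vocab_size=100, hidden_size=64, num_hidden_layers=2,
                     num_attention_heads=4, intermediate_size=128,
                     max_position_embeddings=128)
model = LlamaModel(config)

input_ids = torch.randint(0, 100, (1, 8))
out = model(input_ids=input_ids)
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 64])
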

LlamaRMSNorm

Bases: Module

Source code in src/aeiva/model/macaw_model.py
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
__init__(hidden_size, eps=1e-06)

LlamaRMSNorm is equivalent to T5LayerNorm

Source code in src/aeiva/model/macaw_model.py
def __init__(self, hidden_size, eps=1e-6):
    """
    LlamaRMSNorm is equivalent to T5LayerNorm
    """
    super().__init__()
    self.weight = nn.Parameter(torch.ones(hidden_size))
    self.variance_epsilon = eps
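
A quick numerical check of the normalization, as a sketch: with the default unit weights, the output has (approximately) unit root mean square.

import torch
from aeiva.model.macaw_model import LlamaRMSNorm

norm = LlamaRMSNorm(hidden_size=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
y = norm(x)
print(y.pow(2).mean(-1))  # ~1.0
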

MM_LLMs_Config

Bases: PretrainedConfig

Source code in src/aeiva/model/macaw_model.py
class MM_LLMs_Config(PretrainedConfig):
    model_type = 'mm_llms'
    is_composition = True

    def __init__(self, n_frames=6, attention_heads=8, image_conv_kernel=48, image_conv_stride=36,
                 video_conv_kernel=36, video_conv_stride=30, audio_conv_kernel=240, audio_conv_stride=220,
                 clip_config=None, whisper_config=None, llm_config=None, **kwargs):

        self.image_config = clip_config
        self.audio_config = whisper_config
        self.llm_config = llm_config
        self.n_frames = n_frames
        self.attention_heads = attention_heads
        self.image_conv_kernel = image_conv_kernel
        self.image_conv_stride = image_conv_stride
        self.video_conv_kernel = video_conv_kernel
        self.video_conv_stride = video_conv_stride
        self.audio_conv_kernel = audio_conv_kernel
        self.audio_conv_stride = audio_conv_stride

        self.hidden_size = max(llm_config.hidden_size, clip_config.projection_dim, whisper_config.d_model, clip_config.projection_dim)

        super().__init__(**kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        output["image_config"] = self.image_config.to_dict()
        output["audio_config"] = self.audio_config.to_dict()
        output['llm_config'] = self.llm_config.to_dict()
        output['n_frames'] = self.n_frames
        output['attention_heads'] = self.attention_heads
        output['image_conv_kernel'] = self.image_conv_kernel
        output['image_conv_stride'] = self.image_conv_stride
        output['video_conv_kernel'] = self.video_conv_kernel
        output['video_conv_stride'] = self.video_conv_stride
        output['audio_conv_kernel'] = self.audio_conv_kernel
        output['audio_conv_stride'] = self.audio_conv_stride
        output['hidden_size'] = self.hidden_size
        output["model_type"] = self.__class__.model_type
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        clip_config = CLIPConfig.from_dict(config_dict['image_config'])
        whisper_config = WhisperConfig.from_dict(config_dict['audio_config'])
        llm_config = LlamaConfig.from_dict(config_dict['llm_config'])

        return cls(clip_config=clip_config, whisper_config=whisper_config, llm_config=llm_config, **kwargs)
to_dict()

Serializes this instance to a Python dictionary. Override the default [~PretrainedConfig.to_dict].

Returns:

`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.

Source code in src/aeiva/model/macaw_model.py
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    output = copy.deepcopy(self.__dict__)
    output["image_config"] = self.image_config.to_dict()
    output["audio_config"] = self.audio_config.to_dict()
    output['llm_config'] = self.llm_config.to_dict()
    output['n_frames'] = self.n_frames
    output['attention_heads'] = self.attention_heads
    output['image_conv_kernel'] = self.image_conv_kernel
    output['image_conv_stride'] = self.image_conv_stride
    output['video_conv_kernel'] = self.video_conv_kernel
    output['video_conv_stride'] = self.video_conv_stride
    output['audio_conv_kernel'] = self.audio_conv_kernel
    output['audio_conv_stride'] = self.audio_conv_stride
    output['hidden_size'] = self.hidden_size
    output["model_type"] = self.__class__.model_type
    return output
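
A construction sketch from sub-configs, mirroring what from_pretrained does; the default CLIP/Whisper/Llama configs are used purely for illustration:

from transformers import CLIPConfig, WhisperConfig, LlamaConfig
from aeiva.model.macaw_model import MM_LLMs_Config

mm_config = MM_LLMs_Config(
    n_frames=6,
    clip_config=CLIPConfig(),
    whisper_config=WhisperConfig(),
    llm_config=LlamaConfig(),
)
d = mm_config.to_dict()  # nested configs are serialized to plain dicts
print(d["model_type"])   # 'mm_llms'
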

WhisperEncoder

Bases: WhisperPreTrainedModel

Transformer encoder consisting of config.encoder_layers self attention layers. Each layer is a [WhisperEncoderLayer].

Parameters:

- config (`WhisperConfig`, required): WhisperConfig

Source code in src/aeiva/model/macaw_model.py
class WhisperEncoder(WhisperPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`WhisperEncoderLayer`].

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)

        self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _freeze_parameters(self):
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def get_input_embeddings(self) -> nn.Module:
        return self.conv1

    def set_input_embeddings(self, value: nn.Module):
        self.conv1 = value

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`, *optional*):
                Whisper does not support masking of the `input_features`; this argument is preserved for compatibility,
                but it is not used. By default the silence in the input log mel spectrogram is ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
forward(input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)

Parameters:

- input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`, required): Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, e.g. via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the `AutoFeatureExtractor` should be used for extracting the mel features, padding and conversion into a tensor of type `torch.FloatTensor`. See `WhisperFeatureExtractor.__call__`.
- attention_mask (`torch.Tensor`, *optional*, default `None`): Whisper does not support masking of the `input_features`; this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram is ignored.
- head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*, default `None`): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 1 indicates the head is not masked, 0 indicates the head is masked.
- output_attentions (`bool`, *optional*, default `None`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
- output_hidden_states (`bool`, *optional*, default `None`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
- return_dict (`bool`, *optional*, default `None`): Whether or not to return a `~utils.ModelOutput` instead of a plain tuple.

Source code in src/aeiva/model/macaw_model.py
def forward(
    self,
    input_features,
    attention_mask=None,
    head_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    r"""
    Args:
        input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
            `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
            and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
        attention_mask (`torch.Tensor`, *optional*):
            Whisper does not support masking of the `input_features`; this argument is preserved for compatibility,
            but it is not used. By default the silence in the input log mel spectrogram is ignored.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    inputs_embeds = nn.functional.gelu(self.conv1(input_features))
    inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

    inputs_embeds = inputs_embeds.permute(0, 2, 1)
    embed_pos = self.embed_positions.weight

    hidden_states = inputs_embeds + embed_pos
    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    encoder_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None

    # check if head_mask has a correct number of layers specified if desired
    if head_mask is not None:
        assert head_mask.size()[0] == (
            len(self.layers)
        ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

    for idx, encoder_layer in enumerate(self.layers):
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):  # skip the layer
            layer_outputs = (None, None)
        else:
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    None,
                    (head_mask[idx] if head_mask is not None else None),
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)

    hidden_states = self.layer_norm(hidden_states)
    if output_hidden_states:
        encoder_states = encoder_states + (hidden_states,)

    if not return_dict:
        return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
    return BaseModelOutput(
        last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
    )
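
The loop above implements LayerDrop: during training, each encoder layer is skipped outright with probability `self.layerdrop`. A minimal, self-contained sketch of the same pattern (toy linear layers and a made-up drop rate stand in for the real encoder layers):

import random

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])  # stand-in encoder layers
layerdrop = 0.5  # illustrative rate; the real value comes from the config
training = True

hidden_states = torch.randn(2, 5, 8)  # (batch, seq_len, hidden)
for layer in layers:
    # One draw per layer per forward pass; skip the layer entirely
    # (identity mapping) if the draw falls below the LayerDrop rate.
    if training and random.uniform(0, 1) < layerdrop:
        continue
    hidden_states = layer(hidden_states)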

rotate_half(x)

Rotates half the hidden dims of the input.

Source code in src/aeiva/model/macaw_model.py
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
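
A quick numeric check of `rotate_half`: the last dimension is split in half, and the negated second half is swapped in front of the first half. This replicates the function body on a concrete input:

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
x1, x2 = x[..., :2], x[..., 2:]       # the two halves
print(torch.cat((-x2, x1), dim=-1))   # tensor([-3., -4.,  1.,  2.])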

macaw_model_old

This script contains the implementation of the MACAW model. MACAW is a multimodal transformer model that combines the CLIP and Whisper models.

Author: Bang Liu
Date: 2023-06-22

References:
- Macaw-LLM code repository: https://github.com/lyuchenyang/Macaw-LLM/blob/main/modeling.py

LlamaAttention

Bases: Module

Multi-headed attention from 'Attention Is All You Need' paper

Source code in src/aeiva/model/macaw_model_old.py
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings  # !!! I want to change this variable name.

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        # By placing the num_heads dimension as the second dimension, it allows for 
        # efficient batched matrix operations (e.g., matrix multiplication in attention computation) 
        # across all the heads. It is basically a data layout optimization for computational efficiency 
        # in the context of multi-head attention.
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]  # the shape is [batch_size, num_heads, seq_len, head_dim], so -2 dimension is 'seq_len'
        if past_key_value is not None:  
            # If past_key_value is not None, this means the model is being used in an autoregressive setting, 
            # where the past key-value pairs are given to the current step.
            # past_key_value[0] refers to the previously computed key states,
            # past_key_value[1] refers to the previously computed value states.
            # The shape of past_key_value[0] and past_key_value[1] is [batch_size, num_heads, seq_len, head_dim].
            kv_seq_len += past_key_value[0].shape[-2]  # + past seq_len

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # This following line is ensuring numerical stability. It caps the minimum value of the attention weights
            # to be the minimum finite representable number for the data type of attn_weights. This avoids 
            # potential issues with underflow when these weights are later passed through the softmax function.
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        # This is done to prevent numerical instability that can occur
        # during operations on very small numbers or very large numbers.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) # self.hidden_size is equivalent to self.num_heads * self.head_dim

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
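
The `view(...).transpose(1, 2)` pattern above is the standard multi-head layout: it moves the head dimension next to the batch dimension so all heads share one batched matmul. A minimal sketch of the shape flow with toy sizes (independent of this class):

import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(bsz, q_len, num_heads * head_dim)

# (bsz, q_len, hidden_size) -> (bsz, num_heads, q_len, head_dim)
q = hidden.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k = hidden.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

# One batched matmul computes the scaled scores for every head at once.
scores = torch.matmul(q, k.transpose(2, 3)) / head_dim ** 0.5
print(scores.shape)  # torch.Size([2, 4, 5, 5])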

LlamaDecoderLayer

Bases: Module

Source code in src/aeiva/model/macaw_model_old.py
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)

Parameters:

- hidden_states (`torch.FloatTensor`, required): input to the layer of shape `(batch, seq_len, embed_dim)`.
- attention_mask (`torch.FloatTensor`, *optional*, default `None`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- output_attentions (`bool`, *optional*, default `False`): whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
- use_cache (`bool`, *optional*, default `False`): if set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*, default `None`): cached past key and value projection states.
Source code in src/aeiva/model/macaw_model_old.py
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
    """

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs
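
Both sub-blocks above follow the pre-norm residual pattern: normalize the input, run the block, then add the original input back. The pattern in isolation (stand-in modules, not the real RMSNorm or attention):

import torch
import torch.nn as nn

norm = nn.LayerNorm(16)    # stand-in for LlamaRMSNorm
block = nn.Linear(16, 16)  # stand-in for self-attention or the MLP

x = torch.randn(2, 5, 16)
residual = x
x = block(norm(x))  # pre-norm: normalize *before* the block
x = residual + x    # residual connection preserves the original signal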

LlamaModel

Bases: LlamaPreTrainedModel

Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [LlamaDecoderLayer]

Parameters:

- config (`LlamaConfig`, required): model configuration.
Source code in src/aeiva/model/macaw_model_old.py
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # embedding layer, stacked decoder layers, and layer normalization in llama.
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Gradient checkpointing is a technique to reduce the memory usage when training deep neural networks.
        # In deep learning, when you perform backpropagation to compute gradients and update the model parameters,
        # you need to store the intermediate activations from the forward pass, so you can use them in the backward pass. 
        # For large models or long sequences, this can consume a lot of memory.
        # 
        # Gradient checkpointing addresses this by not storing all the intermediate activations in memory during the forward pass. 
        # Instead, it stores only a subset of the activations, and recomputes the rest during the backward pass as needed. 
        # This trades off computation time (because you need to recompute some values) for memory usage.
        # 
        # This technique is particularly useful when training large models that would otherwise not fit into GPU memory. 
        # However, it can slow down training because of the extra computation.
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:  # seq_len > 1
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # set output and cache flags
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # prepare input_ids/inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # prepare attention mask and other parameters for decoder layers
        past_key_values_length = 0
        seq_length_with_past = seq_length

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # forward through all decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # define the function for gradient checkpointing
                # in checkpointing, we need to create a custom function for the forward pass 
                # (the custom_forward function in your code) and then using the 
                # torch.utils.checkpoint.checkpoint function to apply this custom function 
                # with gradient checkpointing.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions, None)  # None for past_key_value
                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        # output the hidden states, the self attentions and the cache (if needed)
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
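
The `past_key_values` plumbing above is what makes incremental decoding cheap: keys and values from earlier steps are cached, so each new step concatenates one new position instead of recomputing the whole prefix. A toy sketch of the cache growing along the sequence dimension:

import torch

bsz, num_heads, head_dim = 1, 4, 8
past_k = torch.randn(bsz, num_heads, 6, head_dim)  # 6 cached positions
new_k = torch.randn(bsz, num_heads, 1, head_dim)   # the current token only

k = torch.cat([past_k, new_k], dim=2)  # dim 2 is seq_len, as in the code above
print(k.shape)  # torch.Size([1, 4, 7, 8])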

LlamaRMSNorm

Bases: Module

Source code in src/aeiva/model/macaw_model_old.py
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        The overall effect of this layer is to ensure that,
        for each feature in the hidden_states,
        the activations have zero mean and unit variance across the batch.
        This can make the training process more stable and faster.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # trainable parameter for affine transformation
        self.variance_epsilon = eps  # for numerical stability

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
__init__(hidden_size, eps=1e-06)

LlamaRMSNorm is equivalent to T5LayerNorm. It rescales each hidden-state vector to unit root mean square over the feature dimension (per token, with no mean subtraction, unlike standard LayerNorm), then applies a learned per-feature scale. This can make the training process more stable and faster.

Source code in src/aeiva/model/macaw_model_old.py
def __init__(self, hidden_size, eps=1e-6):
    """
    LlamaRMSNorm is equivalent to T5LayerNorm
    The overall effect of this layer is to ensure that,
    for each feature in the hidden_states,
    the activations have zero mean and unit variance across the batch.
    This can make the training process more stable and faster.
    """
    super().__init__()
    self.weight = nn.Parameter(torch.ones(hidden_size))  # trainable parameter for affine transformation
    self.variance_epsilon = eps  # for numerical stability
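
A minimal numeric check of the normalization in `forward` (toy tensor; the learned `weight` is all ones at initialization, so it is omitted):

import torch

eps = 1e-6
h = torch.tensor([[3.0, 4.0]])              # one token with two features
variance = h.pow(2).mean(-1, keepdim=True)  # mean of squares = 12.5
out = h * torch.rsqrt(variance + eps)       # divide by the RMS (~3.5355)
print(out)                # tensor([[0.8485, 1.1314]])
print(out.pow(2).mean())  # ~1.0: unit mean square per token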

LlamaRotaryEmbedding

Bases: Module

Rotary embedding described in: https://arxiv.org/pdf/2104.09864.pdf. It is used to modulate the position information in the input embeddings. Llama used rotary embedding.

Source code in src/aeiva/model/macaw_model_old.py
class LlamaRotaryEmbedding(torch.nn.Module):
    """
    Rotary embedding described in: https://arxiv.org/pdf/2104.09864.pdf.
    It is used to modulate the position information in the input embeddings.
    Llama used rotary embedding.
    """
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # Compute the inverse frequencies, which will be used to modulate the position information
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        # The register_buffer() function is used in PyTorch to register a tensor that is not a parameter,
        # but you still want it to be a part of the model's state. It's used for tensors that should
        # have their state saved in the model's state_dict and should be moved to the device with the rest of the model.
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        # max_position_embeddings: max sequence length that this model might ever be used with
        self.max_seq_len_cached = max_position_embeddings

        # Compute the positional encodings (both cos and sin parts)
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        # x.shape: [batch_size, num_attention_heads, sequence_length, head_size].
        # The forward function then outputs two tensors, each of which is a sin or cos embedding representation of the input x. 
        # Both output tensors will have a shape of [1, 1, sequence_length, head_size].
        # NOTE: Only the dtype and device attributes of x are relevant here. The values are not used.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
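
The cached `cos`/`sin` tables are consumed together with `rotate_half` by `apply_rotary_pos_emb`. A hedged sketch of the standard rotary formula, `q' = q * cos + rotate_half(q) * sin` (this mirrors the common Llama implementation; gathering the tables by `position_ids` is elided, and the values below are random stand-ins):

import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

q = torch.randn(1, 1, 4, 8)    # (batch, heads, seq_len, head_dim)
cos = torch.randn(1, 1, 4, 8)  # toy stand-in for the cached cosine table
sin = torch.randn(1, 1, 4, 8)  # toy stand-in for the cached sine table

q_rot = q * cos + rotate_half(q) * sin  # position-dependent rotation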

MM_LLMs

Bases: PreTrainedModel

This is the multimodal language model that combines CLIP and Whisper encoders with a language model. We need a config file to specify the multimodal encoder configurations.

Source code in src/aeiva/model/macaw_model_old.py
class MM_LLMs(PreTrainedModel):
    """
    This is the multimodal language model that combines CLIP and Whisper encoders with a language model.
    We need a config file to specify the multimodal encoder configurations.
    """
    def __init__(self, config):
        super().__init__(config)
        # multimodal config
        self.config = config

        # multimodal encoders
        self.image_encoder = CLIPModel(config.image_config)  # NOTE: here they use CLIP for both image and video.
        self.video_encoder = CLIPModel(config.image_config)
        self.audio_encoder = WhisperModel(config.audio_config)
        self.llm = LlamaForCausalLM(config.llm_config)

        # video temporal position embedding layer
        self.temporal_position_embeddings = nn.Embedding(
            config.n_frames, 
            config.image_config.projection_dim)

        # multimodal attention layers for mapping multimodal features to the same space
        attn_dropout = 0.1
        is_add_bias_kv = True
        is_add_zero_attn = True
        self.temporal_self_attention = nn.MultiheadAttention(config.image_config.projection_dim,
                                                             config.attention_heads,
                                                             dropout=attn_dropout,
                                                             add_bias_kv=is_add_bias_kv,
                                                             add_zero_attn=is_add_zero_attn)
        self.video_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, 
                                                             config.attention_heads,
                                                             dropout=attn_dropout,
                                                             add_bias_kv=is_add_bias_kv,
                                                             add_zero_attn=is_add_zero_attn)
        self.audio_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, 
                                                             config.attention_heads,
                                                             dropout=attn_dropout,
                                                             add_bias_kv=is_add_bias_kv,
                                                             add_zero_attn=is_add_zero_attn)
        self.image_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, 
                                                             config.attention_heads,
                                                             dropout=attn_dropout,
                                                             add_bias_kv=is_add_bias_kv,
                                                             add_zero_attn=is_add_zero_attn)

        # multimodal projection layers for mapping multimodal features to the same space
        self.transform_video_to_hidden = nn.Linear(config.image_config.projection_dim, 
                                                   config.llm_config.hidden_size)
        self.transform_audio_to_hidden = nn.Linear(config.audio_config.d_model, 
                                                   config.llm_config.hidden_size)
        self.transform_image_to_hidden = nn.Linear(config.image_config.projection_dim, 
                                                   config.llm_config.hidden_size)

        self.project_image = nn.Conv1d(config.image_config.projection_dim, config.image_config.projection_dim, 
        kernel_size=48, stride=36)
        self.project_video = nn.Conv1d(config.image_config.projection_dim, config.image_config.projection_dim, 
        kernel_size=36, stride=30)
        self.project_audio = nn.Conv1d(config.audio_config.d_model, config.audio_config.d_model, 
        kernel_size=240, stride=220)

        # multimodal fusion layers
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.layer_norm = nn.LayerNorm(config.image_config.projection_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        self.gelu = nn.GELU()
        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()

        self.loss_fct = CrossEntropyLoss()

        self.init_weights()

    def forward(self, inputs=None):
        # """
        # :param inputs:
        #             video_frames: (B x F)
        #             audios: B x 1
        #             images: B x 1
        #             input_ids: B x L
        #             labels: B x L
        #
        # :return: the output of the language model LlamaForCausalLM.
        # """
        text_embeddings, attention_mask, labels = self.prepare_inputs_for_generation(inputs)

        if 'inference' in inputs and inputs['inference'] is True:
            # generate_ids = self.llm.generate(input_ids=inputs['input_ids'], inputs_embeds=text_embeddings, max_new_tokens=128)
            # generate_ids = self.llm.generate(inputs_embeds=text_embeddings, max_new_tokens=128)

            # !!! The code below may trigger a DeepSpeed error; see https://github.com/microsoft/DeepSpeed/issues/3156 (the suggested fix only partially resolves the bug for me).
            generate_ids = self.llm.generate(
                inputs_embeds=text_embeddings, max_new_tokens=128, eos_token_id=2, bos_token_id=1, pad_token_id=32006  # !!! revise later. use config constants instead.
                )
            return generate_ids
        outputs = self.llm(inputs_embeds=text_embeddings, attention_mask=attention_mask, labels=labels)

        return outputs

    def prepare_inputs_for_generation(self, inputs):
        """
        The purpose of this method is to integrate the different modalities into the text embeddings 
        and prepare the associated attention mask and labels for the language model, so the model can 
        generate text conditioned on all the input modalities.

        inputs is a dictionary containing the following keys: (!!! my hypothesis)
            video_frames: (B x F)
            audios: B x 1
            images: B x 1
            input_ids: B x L
            attention_mask: B x L
            labels: B x L
            video_starts: B x 1
            video_ends: B x 1
            audio_starts: B x 1
            audio_ends: B x 1
            image_starts: B x 1
            image_ends: B x 1
            inference: True/False
        """
        # get multimodal embeddings
        image_features = self.encode_image(inputs['images']) if inputs['images'] is not None else None
        audio_features = self.encode_audio(inputs['audios']) if inputs['audios'] is not None else None
        video_features = self.encode_video(inputs['videos']) if inputs['videos'] is not None else None
        embed_tokens = self.llm.model.embed_tokens


        # for debug !!!!!!
        # Find maximum id in input_ids
        max_id = torch.max(inputs['input_ids'])
        print(f"Max ID in input_ids: {max_id.item()}")

        # Get vocab size from embedding layer
        vocab_size = embed_tokens.num_embeddings
        print(f"Vocabulary size: {vocab_size}")



        text_embeddings = embed_tokens(inputs['input_ids'])

        token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(
            text_embeddings.size(0), 1, 1).transpose(0, 1)

        # ignore_num tracks the total length of the multimodal input segments
        # (video, audio, image) inserted into the original text inputs.
        ignore_num = 0

        # project and merge video features to the same space as text embeddings
        if video_features is not None:
            # get video starts and ends embeddings
            video_starts = embed_tokens(inputs['video_starts']).unsqueeze(1)
            video_ends = embed_tokens(inputs['video_ends']).unsqueeze(1)

            # project video features to the same space as text embeddings
            video_features = self.transform_video_to_hidden(video_features)

            video_features = self.video_align_attention(
                video_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()

            # concatenate video starts, video features, and video ends embeddings
            video_inputs = torch.cat([torch.cat([video_starts, video_features], dim=1), video_ends], dim=1)

            # concatenate video inputs to the original text embeddings
            # NOTE: the first token of text_embeddings keeps at the same position
            text_embeddings = torch.cat([torch.cat([text_embeddings[:, 0, :].unsqueeze(1), video_inputs], dim=1), text_embeddings[:, 1:, :]], dim=1)

            ignore_num += video_inputs.size(1)

        # project and merge audio features to the same space as text embeddings
        if audio_features is not None:
            # get audio starts and ends embeddings
            audio_starts = embed_tokens(inputs['audio_starts']).unsqueeze(1)
            audio_ends = embed_tokens(inputs['audio_ends']).unsqueeze(1)

            # project audio features to the same space as text embeddings
            audio_features = self.project_audio(audio_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()
            audio_features = self.transform_audio_to_hidden(audio_features)
            # mean pooling
            # audio_features = torch.sum(audio_features, dim=1) / audio_features.size(1) 
            # audio_features = audio_features.unsqueeze(1)
            audio_features = self.audio_align_attention(
                audio_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()

            # concatenate audio starts, audio features, and audio ends embeddings
            audio_inputs = torch.cat([torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)

            # concatenate audio inputs to the original text embeddings
            text_embeddings = torch.cat(
                [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],
                dim=1)

            ignore_num += audio_inputs.size(1)

        # project and merge image features to the same space as text embeddings
        if image_features is not None:
            # get image starts and ends embeddings
            image_starts = embed_tokens(inputs['image_starts']).unsqueeze(1)
            image_ends = embed_tokens(inputs['image_ends']).unsqueeze(1)

            # project image features to the same space as text embeddings
            image_features = self.project_image(image_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()
            image_features = self.transform_image_to_hidden(image_features)
            image_features = self.image_align_attention(
                image_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()

            # concatenate image starts, image features, and image ends embeddings
            image_inputs = torch.cat([torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)

            # concatenate image inputs to the original text embeddings
            text_embeddings = torch.cat(
                [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1), 
                text_embeddings[:, 1:, :]], dim=1)

            ignore_num += image_inputs.size(1)

        if 'attention_mask' in inputs:
            # increase the length of attention mask by adding the length of multimodal inputs
            attention_mask = torch.tensor([1]*ignore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1)  # (B x ignore_num)
            attention_mask = torch.cat([attention_mask, inputs['attention_mask']], dim=1)
        else:
            attention_mask = None

        if 'labels' in inputs and inputs['labels'] is not None:
            # increase the length of labels by adding the length of labels
            # we use -100 to ignore the loss of labels in multimodal inputs
            # !!! we can replace -100 by config constants to make the code better

            # since the tokens corresponding to the image_inputs, audio_inputs, and video_inputs are not part of the original text 
            # and don't have corresponding labels in the true text sequence, their labels are set to -100. This ensures that 
            # the model's predictions for these tokens don't affect the loss and, consequently, the gradients and the model's subsequent learning.
            labels = torch.tensor([-100]*ignore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1)
            labels = torch.cat([labels, inputs['labels']], dim=1)
        else:
            labels = None

        # text_embeddings: (batch_size, sequence_length + ignore_num, embedding_dim)
        # attention_mask: (batch_size, sequence_length + ignore_num). 1 denotes a token we should attend to; 0 denotes a token we should not attend to.
        # labels: (batch_size, sequence_length + ignore_num). -100 denotes a token whose loss is ignored.
        return text_embeddings, attention_mask, labels

    def encode_video(self, videos):
        """
        Encode video features to video embeddings.

        Args:
            videos: (batch_size, n_frames, n_channels, height, width)

        Returns:
            video_embeddings: (batch_size, n_frames, embedding_dim)
        """
        # simple image encoding without temporal embedding and self attention
        # Reference: https://huggingface.co/docs/transformers/model_doc/clip
        videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1))  # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width)) 
        video_outputs = self.video_encoder.get_image_features(videos)  # image_features (torch.FloatTensor of shape (batch_size * n_frames, output_dim)
        video_features = video_outputs
        temporal_pos = torch.tensor(
            [[i for i in range(self.config.n_frames)] 
            for j in range(videos.size(0) // self.config.n_frames)],
            dtype=torch.int, device=video_features.device).view(-1)  # 2d indices to 1d indices, shape: (batch_size * n_frames)

        frame_temporal_pos_embed = self.temporal_position_embeddings(temporal_pos)

        video_features = (video_features + frame_temporal_pos_embed).view(
            videos.size(0) // self.config.n_frames, self.config.n_frames, -1)  # (batch_size, n_frames, output_dim)

        video_features = video_features.transpose(0, 1).contiguous()
        # nn.MultiheadAttention takes query, key, value as inputs. Their shapes are (sequence_length, batch_size, embedding_dim).
        # The outputs are two elements: attn_output of shape (sequence_length, batch_size, embedding_dim), and attn_output_weights of shape (batch_size, sequence_length, sequence_length).
        self_attn_video_features = self.temporal_self_attention(video_features, video_features, video_features)[0]

        return self_attn_video_features.transpose(0, 1).contiguous() # (batch_size, n_frames, output_dim)

    def encode_video_long(self, videos):
        """
        Encode video features to video embeddings.

        Args:
            videos: (batch_size, n_frames, n_channels, height, width)

        Returns:
            video_embeddings: (batch_size, n_frames, embedding_dim)
        """
        # simple image encoding without temporal embedding and self attention
        videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1))  # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width))
        video_features = self.video_encoder.visual_projection(self.video_encoder.vision_model(videos)[0])[:, 1:, :]
        video_features = video_features.reshape(
            videos.size(0) // self.config.n_frames,
            self.config.n_frames * video_features.size(1),
            -1).contiguous()

        return video_features

    def encode_audio(self, audios):
        audio_features = self.audio_encoder.encoder(audios)
        return audio_features[0]

    def encode_image(self, images):
        # vision_outputs = self.image_encoder.get_image_features(images)
        # image_features = vision_outputs  # pooled_output
        # image_features = self.visual_projection(pooled_output)
        # image_features = image_features.unsqueeze(1)
        image_features = self.image_encoder.visual_projection(self.image_encoder.vision_model(images)[0])[:, 1:, :]
        return image_features
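
Based on the keys accessed above, a hedged sketch of the inputs dictionary that forward and prepare_inputs_for_generation expect (shapes follow the docstring's own hypothesis; the token ids and marker ids below are random stand-ins, not real special tokens):

import torch

inputs = {
    "videos": None,   # or pixel values of shape (B, n_frames, C, H, W)
    "audios": None,   # or Whisper log-mel input features
    "images": None,   # or pixel values of shape (B, C, H, W)
    "input_ids": torch.randint(0, 32000, (2, 16)),          # B x L
    "attention_mask": torch.ones(2, 16, dtype=torch.long),  # B x L
    "labels": torch.randint(0, 32000, (2, 16)),             # B x L
    "video_starts": torch.tensor([1, 1]),  # B x 1 marker token ids
    "video_ends": torch.tensor([2, 2]),
    "audio_starts": torch.tensor([3, 3]),
    "audio_ends": torch.tensor([4, 4]),
    "image_starts": torch.tensor([5, 5]),
    "image_ends": torch.tensor([6, 6]),
    "inference": False,  # True switches forward() to llm.generate()
}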
encode_video(videos)

Encode video features to video embeddings.

Parameters:

- videos (required): (batch_size, n_frames, n_channels, height, width)

Returns:

- video_embeddings: (batch_size, n_frames, embedding_dim)

Source code in src/aeiva/model/macaw_model_old.py
def encode_video(self, videos):
    """
    Encode video features to video embeddings.

    Args:
        videos: (batch_size, n_frames, n_channels, height, width)

    Returns:
        video_embeddings: (batch_size, n_frames, embedding_dim)
    """
    # simple image encoding without temporal embedding and self attention
    # Reference: https://huggingface.co/docs/transformers/model_doc/clip
    videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1))  # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width)) 
    video_outputs = self.video_encoder.get_image_features(videos)  # image_features (torch.FloatTensor of shape (batch_size * n_frames, output_dim)
    video_features = video_outputs
    temporal_pos = torch.tensor(
        [[i for i in range(self.config.n_frames)] 
        for j in range(videos.size(0) // self.config.n_frames)],
        dtype=torch.int, device=video_features.device).view(-1)  # 2d indices to 1d indices, shape: (batch_size * n_frames)

    frame_temporal_pos_embed = self.temporal_position_embeddings(temporal_pos)

    video_features = (video_features + frame_temporal_pos_embed).view(
        videos.size(0) // self.config.n_frames, self.config.n_frames, -1)  # (batch_size, n_frames, output_dim)

    video_features = video_features.transpose(0, 1).contiguous()
    # nn.MultiheadAttention takes query, key, value as inputs. Their shapes are (sequence_length, batch_size, embedding_dim).
    # The outputs are two elements: attn_output of shape (sequence_length, batch_size, embedding_dim), and attn_output_weights of shape (batch_size, sequence_length, sequence_length).
    self_attn_video_features = self.temporal_self_attention(video_features, video_features, video_features)[0]

    return self_attn_video_features.transpose(0, 1).contiguous() # (batch_size, n_frames, output_dim)
encode_video_long(videos)

Encode video features to video embeddings.

Parameters:

- videos (required): (batch_size, n_frames, n_channels, height, width)

Returns:

- video_embeddings: (batch_size, n_frames, embedding_dim)

Source code in src/aeiva/model/macaw_model_old.py
def encode_video_long(self, videos):
    """
    Encode video features to video embeddings.

    Args:
        videos: (batch_size, n_frames, n_channels, height, width)

    Returns:
        video_embeddings: (batch_size, n_frames, embedding_dim)
    """
    # simple image encoding without temporal embedding and self attention
    videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1))  # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width))
    video_features = self.video_encoder.visual_projection(self.video_encoder.vision_model(videos)[0])[:, 1:, :]
    video_features = video_features.reshape(
        videos.size(0) // self.config.n_frames,
        self.config.n_frames * video_features.size(1),
        -1).contiguous()

    return video_features
prepare_inputs_for_generation(inputs)

The purpose of this method is to integrate the different modalities into the text embeddings and prepare the associated attention mask and labels for the language model, so the model can generate text conditioned on all the input modalities.

inputs is a dictionary containing the following keys (the author's hypothesis):

- video_frames: (B x F)
- audios: B x 1
- images: B x 1
- input_ids: B x L
- attention_mask: B x L
- labels: B x L
- video_starts / video_ends: B x 1
- audio_starts / audio_ends: B x 1
- image_starts / image_ends: B x 1
- inference: True/False

Source code in src/aeiva/model/macaw_model_old.py
def prepare_inputs_for_generation(self, inputs):
    """
    The purpose of this method is to integrate the different modalities into the text embeddings 
    and prepare the associated attention mask and labels for the language model, so the model can 
    generate text conditioned on all the input modalities.

    inputs is a dictionary containing the following keys: (!!! my hypothesis)
        video_frames: (B x F)
        audios: B x 1
        images: B x 1
        input_ids: B x L
        attention_mask: B x L
        labels: B x L
        video_starts: B x 1
        video_ends: B x 1
        audio_starts: B x 1
        audio_ends: B x 1
        image_starts: B x 1
        image_ends: B x 1
        inference: True/False
    """
    # get multimodal embeddings
    image_features = self.encode_image(inputs['images']) if inputs['images'] is not None else None
    audio_features = self.encode_audio(inputs['audios']) if inputs['audios'] is not None else None
    video_features = self.encode_video(inputs['videos']) if inputs['videos'] is not None else None
    embed_tokens = self.llm.model.embed_tokens


    # for debug !!!!!!
    # Find maximum id in input_ids
    max_id = torch.max(inputs['input_ids'])
    print(f"Max ID in input_ids: {max_id.item()}")

    # Get vocab size from embedding layer
    vocab_size = embed_tokens.num_embeddings
    print(f"Vocabulary size: {vocab_size}")



    text_embeddings = embed_tokens(inputs['input_ids'])

    token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(
        text_embeddings.size(0), 1, 1).transpose(0, 1)

    # ignore_num tracks the total length of the multimodal input segments
    # (video, audio, image) inserted into the original text inputs.
    ignore_num = 0

    # project and merge video features to the same space as text embeddings
    if video_features is not None:
        # get video starts and ends embeddings
        video_starts = embed_tokens(inputs['video_starts']).unsqueeze(1)
        video_ends = embed_tokens(inputs['video_ends']).unsqueeze(1)

        # project video features to the same space as text embeddings
        video_features = self.transform_video_to_hidden(video_features)

        video_features = self.video_align_attention(
            video_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()

        # concatenate video starts, video features, and video ends embeddings
        video_inputs = torch.cat([torch.cat([video_starts, video_features], dim=1), video_ends], dim=1)

        # concatenate video inputs to the original text embeddings
        # NOTE: the first token of text_embeddings keeps at the same position
        text_embeddings = torch.cat([torch.cat([text_embeddings[:, 0, :].unsqueeze(1), video_inputs], dim=1), text_embeddings[:, 1:, :]], dim=1)

        ignore_num += video_inputs.size(1)

    # project and merge audio features to the same space as text embeddings
    if audio_features is not None:
        # get audio starts and ends embeddings
        audio_starts = embed_tokens(inputs['audio_starts']).unsqueeze(1)
        audio_ends = embed_tokens(inputs['audio_ends']).unsqueeze(1)

        # project audio features to the same space as text embeddings
        audio_features = self.project_audio(audio_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()
        audio_features = self.transform_audio_to_hidden(audio_features)
        # mean pooling
        # audio_features = torch.sum(audio_features, dim=1) / audio_features.size(1) 
        # audio_features = audio_features.unsqueeze(1)
        audio_features = self.audio_align_attention(
            audio_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()

        # concatenate audio starts, audio features, and audio ends embeddings
        audio_inputs = torch.cat([torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)

        # concatenate audio inputs to the original text embeddings
        text_embeddings = torch.cat(
            [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],
            dim=1)

        ignore_num += audio_inputs.size(1)

    # project and merge image features to the same space as text embeddings
    if image_features is not None:
        # get image starts and ends embeddings
        image_starts = embed_tokens(inputs['image_starts']).unsqueeze(1)
        image_ends = embed_tokens(inputs['image_ends']).unsqueeze(1)

        # project image features to the same space as text embeddings
        image_features = self.project_image(image_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()
        image_features = self.transform_image_to_hidden(image_features)
        image_features = self.image_align_attention(
            image_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()

        # concatenate image starts, image features, and image ends embeddings
        image_inputs = torch.cat([torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)

        # concatenate image inputs to the original text embeddings
        text_embeddings = torch.cat(
            [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1), 
            text_embeddings[:, 1:, :]], dim=1)

        ignore_num += image_inputs.size(1)

    if 'attention_mask' in inputs:
        # Extend the attention mask to cover the prepended multimodal inputs.
        attention_mask = torch.tensor([1]*ignore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1)  # (B x ignore_num)
        attention_mask = torch.cat([attention_mask, inputs['attention_mask']], dim=1)
    else:
        attention_mask = None

    if 'labels' in inputs and inputs['labels'] is not None:
        # Extend the labels by the length of the multimodal inputs.
        # We use -100 so that the loss ignores the multimodal positions.
        # NOTE: -100 could be replaced by a config constant for readability.

        # Since the tokens corresponding to image_inputs, audio_inputs, and video_inputs are not part of the original text
        # and have no corresponding labels in the true text sequence, their labels are set to -100. This ensures that
        # the model's predictions for these tokens don't affect the loss and, consequently, the gradients and the model's subsequent learning.
        labels = torch.tensor([-100]*ignore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1)
        labels = torch.cat([labels, inputs['labels']], dim=1)
    else:
        labels = None

    # text_embeddings: (batch_size, sequence_length + ignore_num, embedding_dim)
    # attention_mask: (batch_size, sequence_length + ignore_num). 1 marks tokens to attend to; 0 marks tokens to ignore.
    # labels: (batch_size, sequence_length + ignore_num). -100 marks tokens to ignore in the loss.
    return text_embeddings, attention_mask, labels
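
The merging pattern above can be illustrated in isolation. The following is a minimal, self-contained sketch (not part of the Aeiva source) of the prepend-and-extend pattern applied for each modality: splice the multimodal embeddings after the first text token, then pad the attention mask with 1s and the labels with -100.

import torch

batch, text_len, mm_len, dim = 2, 5, 3, 8
text_embeddings = torch.randn(batch, text_len, dim)
mm_inputs = torch.randn(batch, mm_len, dim)  # stands in for video/audio/image inputs with start/end tokens
attention_mask = torch.ones(batch, text_len, dtype=torch.long)
labels = torch.randint(0, 100, (batch, text_len))

# Keep the first text token in place and splice the multimodal embeddings after it.
merged = torch.cat([text_embeddings[:, :1, :], mm_inputs, text_embeddings[:, 1:, :]], dim=1)
attention_mask = torch.cat([torch.ones(batch, mm_len, dtype=torch.long), attention_mask], dim=1)
labels = torch.cat([torch.full((batch, mm_len), -100), labels], dim=1)

assert merged.size(1) == attention_mask.size(1) == labels.size(1) == text_len + mm_len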

MM_LLMs_Config

Bases: PretrainedConfig

This is the configuration class to store the configuration of a MM_LLMsModel. It contains class-level and instance-level attributes, and provides from_pretrained for loading and to_dict for saving configuration files.

Source code in src/aeiva/model/macaw_model_old.py
class MM_LLMs_Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a `MM_LLMsModel`.
    It contains class level and instance level attributes.
    It also contains the load (from_pretrained) and save (to_dict) methods for saving and loading configuration files.
    """
    # general class attributes for all model instances
    model_type = 'mm_llms'
    is_composition = True

    def __init__(self, n_frames=6, attention_heads=8, clip_config=None, whisper_config=None, llm_config=None, **kwargs):
        self.image_config = clip_config
        self.audio_config = whisper_config
        self.llm_config = llm_config  # language model config
        self.n_frames = n_frames  # number of frames sampled from each video clip
        self.attention_heads = attention_heads
        # Use the largest sub-model dimension as the shared hidden size.
        self.hidden_size = max(llm_config.hidden_size, clip_config.projection_dim, whisper_config.d_model)
        super().__init__(**kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary, overriding the default [`~PretrainedConfig.to_dict`].
        This method overrides the base class method to include serialization of the 
        image, audio, and language model configurations along with the base configuration.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["image_config"] = self.image_config.to_dict()
        output["audio_config"] = self.audio_config.to_dict()
        output['llm_config'] = self.llm_config.to_dict()
        output['n_frames'] = self.n_frames
        output['attention_heads'] = self.attention_heads
        output['hidden_size'] = self.hidden_size
        output["model_type"] = self.__class__.model_type
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        clip_config = CLIPConfig.from_dict(config_dict['image_config'])
        whisper_config = WhisperConfig.from_dict(config_dict['audio_config'])
        llm_config = LlamaConfig.from_dict(config_dict['llm_config'])

        return cls(clip_config=clip_config, whisper_config=whisper_config, llm_config=llm_config, **kwargs)
to_dict()

Serializes this instance to a Python dictionary, overriding the default [~PretrainedConfig.to_dict]. It includes serialization of the image, audio, and language model configurations along with the base configuration.

Returns:

Type Description

Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.

Source code in src/aeiva/model/macaw_model_old.py
def to_dict(self):
    """
    Serializes this instance to a Python dictionary, overriding the default [`~PretrainedConfig.to_dict`].
    This method overrides the base class method to include serialization of the 
    image, audio, and language model configurations along with the base configuration.

    Returns:
        `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
    """
    output = copy.deepcopy(self.__dict__)
    output["image_config"] = self.image_config.to_dict()
    output["audio_config"] = self.audio_config.to_dict()
    output['llm_config'] = self.llm_config.to_dict()
    output['n_frames'] = self.n_frames
    output['attention_heads'] = self.attention_heads
    output['hidden_size'] = self.hidden_size
    output["model_type"] = self.__class__.model_type
    return output
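
A minimal sketch (not from the Aeiva source) showing how MM_LLMs_Config could be constructed from default Hugging Face sub-configs; the default sub-config settings here are placeholder assumptions, not recommended values.

from transformers import CLIPConfig, WhisperConfig, LlamaConfig
from aeiva.model.macaw_model_old import MM_LLMs_Config

config = MM_LLMs_Config(
    n_frames=6,
    attention_heads=8,
    clip_config=CLIPConfig(),
    whisper_config=WhisperConfig(),
    llm_config=LlamaConfig(),
)
print(config.hidden_size)       # max of the sub-model hidden dimensions
config_dict = config.to_dict()  # includes image_config, audio_config, and llm_config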

WhisperEncoder

Bases: WhisperPreTrainedModel

Transformer encoder consisting of config.encoder_layers self-attention layers. Each layer is a [WhisperEncoderLayer].

Parameters:

Name Type Description Default
config WhisperConfig

WhisperConfig

required
Source code in src/aeiva/model/macaw_model_old.py
class WhisperEncoder(WhisperPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
    [`WhisperEncoderLayer`].

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        # num_mel_bins corresponds to the number of features extracted from the audio signal for each time step. 
        # When we convert audio to a Mel spectrogram, each time step (or frame) in the spectrogram 
        # is represented by a feature vector of size num_mel_bins. 
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        # embed_scale is a scaling factor that is applied to the embeddings.
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

        # position embedding layer
        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)

        self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _freeze_parameters(self):
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def get_input_embeddings(self) -> nn.Module:
        return self.conv1

    def set_input_embeddings(self, value: nn.Module):
        self.conv1 = value

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`, *optional*):
                Whisper does not support masking of the `input_features`; this argument is preserved for
                compatibility but is not used. By default, the silence in the input log mel spectrogram is ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # set output flags
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # embed audio features
        # input_features shape: (batch_size, feature_size, sequence_length)
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))  # (batch_size, embed_dim, sequence_length)
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))  # (batch_size, embed_dim, sequence_length/2), because the stride is 2. Downsampling by 2.
        inputs_embeds = inputs_embeds.permute(0, 2, 1)  #  (batch_size, sequence_length/2, embed_dim)
        embed_pos = self.embed_positions.weight  # (max_source_positions, embed_dim)

        # add position embedding to audio features embedding
        # NOTE: this addition requires sequence_length/2 to equal max_source_positions; Whisper fixes the input to 2 * max_source_positions mel frames.
        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        # go through the whisper encoder layers to get the hidden states and attentions in all layers
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    # The layer_outputs is a tuple of (hidden_states, attention).
                    # The attention is None if output_attentions is False.
                    # hidden_states shape: (batch_size, sequence_length/2, embed_dim), as stride is 2 in the self.conv2
                    # attention shape: (batch_size, num_heads, sequence_length/2, sequence_length/2)
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        # output
        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
forward(input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)

Parameters:

Name Type Description Default
input_features `torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`

Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a .flac or .wav audio file into an array of type List[float] or a numpy.ndarray, e.g. via the soundfile library (pip install soundfile). To prepare the array into input_features, the [AutoFeatureExtractor] should be used for extracting the mel features, padding and conversion into a tensor of type torch.FloatTensor. See [~WhisperFeatureExtractor.__call__]

required
attention_mask `torch.Tensor`, *optional*

Whisper does not support masking of the input_features; this argument is preserved for compatibility, but it is not used. By default, the silence in the input log mel spectrogram is ignored.

None
head_mask `torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*

Mask to nullify selected heads of the attention modules. Mask values selected in [0, 1]:

  • 1 indicates the head is not masked,
  • 0 indicates the head is masked.
None
output_attentions `bool`, *optional*

Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.

None
output_hidden_states `bool`, *optional*

Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.

None
return_dict `bool`, *optional*

Whether or not to return a [~utils.ModelOutput] instead of a plain tuple.

None
Source code in src/aeiva/model/macaw_model_old.py
def forward(
    self,
    input_features,
    attention_mask=None,
    head_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    r"""
    Args:
        input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
            `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
            and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
        attention_mask (`torch.Tensor`, *optional*):
            Whisper does not support masking of the `input_features`; this argument is preserved for
            compatibility but is not used. By default, the silence in the input log mel spectrogram is ignored.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    # set output flags
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # embed audio features
    # input_features shape: (batch_size, feature_size, sequence_length)
    inputs_embeds = nn.functional.gelu(self.conv1(input_features))  # (batch_size, embed_dim, sequence_length)
    inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))  # (batch_size, embed_dim, sequence_length/2), because the stride is 2. Downsampling by 2.
    inputs_embeds = inputs_embeds.permute(0, 2, 1)  #  (batch_size, sequence_length/2, embed_dim)
    embed_pos = self.embed_positions.weight  # (max_source_positions, embed_dim)

    # add position embedding to audio features embedding
    # NOTE: this addition requires sequence_length/2 to equal max_source_positions; Whisper fixes the input to 2 * max_source_positions mel frames.
    hidden_states = inputs_embeds + embed_pos
    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    encoder_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None

    # check if head_mask has a correct number of layers specified if desired
    if head_mask is not None:
        assert head_mask.size()[0] == (
            len(self.layers)
        ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

    # go through the whisper encoder layers to get the hidden states and attentions in all layers
    for idx, encoder_layer in enumerate(self.layers):
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):  # skip the layer
            layer_outputs = (None, None)
        else:
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    None,
                    (head_mask[idx] if head_mask is not None else None),
                )
            else:
                # The layer_outputs is a tuple of (hidden_states, attention).
                # The attention is None if output_attentions is False.
                # hidden_states shape: (batch_size, sequence_length/2, embed_dim), as stride is 2 in the self.conv2
                # attention shape: (batch_size, num_heads, sequence_length/2, sequence_length/2)
                layer_outputs = encoder_layer(
                    hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)

    hidden_states = self.layer_norm(hidden_states)
    if output_hidden_states:
        encoder_states = encoder_states + (hidden_states,)

    # output
    if not return_dict:
        return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
    return BaseModelOutput(
        last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
    )
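
A minimal sketch (not from the Aeiva source) running the encoder on random mel features. Note the input must be 2 * max_source_positions frames long so that the downsampled sequence matches the position table, as discussed in the NOTE above.

import torch
from transformers import WhisperConfig
from aeiva.model.macaw_model_old import WhisperEncoder

config = WhisperConfig()  # defaults include num_mel_bins=80, max_source_positions=1500
encoder = WhisperEncoder(config).eval()

mel = torch.randn(1, config.num_mel_bins, 2 * config.max_source_positions)
with torch.no_grad():
    out = encoder(mel, return_dict=True)
print(out.last_hidden_state.shape)  # (1, max_source_positions, d_model)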

rotate_half(x)

Rotates half the hidden dims of the input.

Source code in src/aeiva/model/macaw_model_old.py
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
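
For example, applied to a vector of length 4 (a quick check of the definition above); this rotation is the building block used by rotary position embeddings in LLaMA-style attention.

import torch
x = torch.tensor([1., 2., 3., 4.])
# x1 = [1, 2], x2 = [3, 4]  ->  rotate_half(x) = [-3, -4, 1, 2]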

operator

custom_ops

macaw_dataitem_ops

This module contains the data item processing functions.

A data item processing function takes a data example (a dict) as input and returns a processed data example.

@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

dataitem_ops

This module contains the data item processing functions.

A data item processing function takes a data example (a dict) as input and returns a processed data example.

@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

dataset_ops

This module contains the utils for processing datasets.

A dataset in aeiva is a dictionary with the following structure:

{
    "data": [ {sample1}, {sample2}, ..., {sampleN} ],
    "metadata": { "num_samples": XX, ... }
}

where each sample is a dictionary itself, and metadata is a dictionary that contains the number of samples and possibly other fields.

@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-13

Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

build_and_merge_datasets(dataset_names, input_filepaths_dict, pipeline, output_dir, max_samples=sys.maxsize)

Build multiple datasets by formatting and processing them.

Source code in src/aeiva/operator/dataset_ops.py
def build_and_merge_datasets(dataset_names: list[str],
                             input_filepaths_dict: dict[str, str],
                             pipeline: list[Callable],
                             output_dir: Optional[str],
                             max_samples: Optional[int] = sys.maxsize) -> DataSet:
    r""" Build multiple datasets by formatting and processing them.
    """
    merged_datasets = []
    for dataset_name in dataset_names:
        dataset = build_dataset(dataset_name, input_filepaths_dict, pipeline, output_dir, max_samples)
        merged_datasets.append(dataset)
    result = merge_datasets(merged_datasets)
    return result

build_dataset(dataset_name, input_filepaths_dict, pipeline, output_dir, max_samples=sys.maxsize)

Build a dataset by formatting and processing it.

Source code in src/aeiva/operator/dataset_ops.py
def build_dataset(dataset_name: str,
                  input_filepaths_dict: dict[str, str],
                  pipeline: list[Callable],
                  output_dir: Optional[str],
                  max_samples: Optional[int] = sys.maxsize) -> DataSet:
    r""" Build a dataset by formatting and processing it.
    """
    operator_type = 'data_formatter'
    format_func = OPERATORS[operator_type][dataset_name]
    formatted_dataset = format_func(input_filepaths_dict, output_dir, max_samples)
    processed_dataset = process_dataset(formatted_dataset, pipeline, output_dir, dataset_name)
    print(f"Completed processing dataset: {dataset_name} (output_dir: {output_dir})")
    return processed_dataset

filter_dataset(dataset, filter_criteria, *args, **kwargs)

Filter a dataset by a filter function.

Source code in src/aeiva/operator/dataset_ops.py
def filter_dataset(dataset: DataSet, filter_criteria: str, *args, **kwargs) -> DataSet:
    r""" Filter a dataset by a filter function.
    """
    operator_type = 'data_filter'
    filter_func = OPERATORS[operator_type][filter_criteria]
    filtered_data = filter_func(dataset, *args, **kwargs)
    return filtered_data

filter_dataset_by_keys(dataset, keys_to_preserve)

Filter the dataset to only include specified keys in each sample.

Source code in src/aeiva/operator/dataset_ops.py
@register_data_filter("filter_dataset_by_keys")
def filter_dataset_by_keys(dataset: DataSet, keys_to_preserve: list[str]) -> DataSet:
    r""" Filter the dataset to only include specified keys in each sample.
    """
    filtered_data = []
    for sample in dataset["data"]:
        for key in keys_to_preserve:
            if key not in sample:
                raise KeyError(f"Key {key} not found in sample")
        filtered_sample = {key: sample[key] for key in keys_to_preserve if key in sample}
        filtered_data.append(filtered_sample)
    return {"data": filtered_data, "metadata": dataset["metadata"]}

merge_datasets(datasets)

Merge multiple datasets into one.

Source code in src/aeiva/operator/dataset_ops.py
def merge_datasets(datasets: list[DataSet]) -> DataSet:
    r""" Merge multiple datasets into one.
    """
    merged_data = []
    total_samples = 0
    for dataset in datasets:
        merged_data.extend(dataset["data"])
        total_samples += dataset["metadata"]["num_samples"]
    result = {"data": merged_data, "metadata": {"num_samples": total_samples}}
    return result

sample_dataset(dataset, n_samples)

Sample a number of samples from a dataset.

Source code in src/aeiva/operator/dataset_ops.py
def sample_dataset(dataset: DataSet, n_samples: int) -> DataSet:
    r""" Sample a number of samples from a dataset.
    """
    random_indices = random.sample(range(dataset["metadata"]["num_samples"]), n_samples)
    sampled_data = [dataset["data"][i] for i in random_indices]
    return {"data": sampled_data, "metadata": {"num_samples": n_samples}}

save_dataset(dataset, output_path)

Save a dataset to a file by pickling it.

Source code in src/aeiva/operator/dataset_ops.py
def save_dataset(dataset: DataSet, output_path: str) -> None:
    r""" Save a dataset to a file by pickling it.
    """
    ensure_dir(output_path)
    with open(output_path, "wb") as f:  # close the file handle deterministically
        pickle.dump(dataset, f, protocol=4)

split_dataset(dataset, train_ratio, seed=42)

Split a dataset into a training set and a validation set.

Source code in src/aeiva/operator/dataset_ops.py
def split_dataset(dataset: dict, train_ratio: float, seed: int = 42) -> Tuple[dict, dict]:
    r""" Split a dataset into a training set and a validation set.
    """
    np.random.seed(seed)  # ensures the function is deterministic

    data = dataset["data"]
    metadata = dataset["metadata"]

    # Create a permutation of indices and shuffle the data.
    perm = np.random.permutation(len(data))
    shuffled_data = [data[i] for i in perm]

    # Calculate split index
    split_idx = int(train_ratio * len(shuffled_data))

    # Split the shuffled data
    train_data = shuffled_data[:split_idx]
    val_data = shuffled_data[split_idx:]

    # Create metadata for training and validation datasets
    train_metadata = metadata.copy()
    train_metadata["num_samples"] = len(train_data)
    val_metadata = metadata.copy()
    val_metadata["num_samples"] = len(val_data)

    # Create training and validation datasets
    train_dataset = {"data": train_data, "metadata": train_metadata}
    val_dataset = {"data": val_data, "metadata": val_metadata}

    return train_dataset, val_dataset
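
A minimal sketch (not from the Aeiva source) exercising split_dataset and merge_datasets on a toy dataset that follows the structure above.

from aeiva.operator.dataset_ops import merge_datasets, split_dataset

toy = {
    "data": [{"text": f"sample {i}"} for i in range(10)],
    "metadata": {"num_samples": 10},
}
train, val = split_dataset(toy, train_ratio=0.8, seed=42)  # 8 / 2 split
merged = merge_datasets([train, val])
assert merged["metadata"]["num_samples"] == 10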

perception

base_perception_system

PerceptionSystem

Bases: ABC

Abstract base class representing the Perception System of an agent.

The Perception System is responsible for capturing raw sensory data from the environment, processing this data into meaningful observations, and providing access to these observations for other components of the cognitive architecture.

Attributes:

Name Type Description
config Any

Configuration settings for the Perception System.

state Any

The internal state of the Perception System, including raw data and observations.

Source code in src/aeiva/perception/base_perception_system.py
class PerceptionSystem(ABC):
    """
    Abstract base class representing the Perception System of an agent.

    The Perception System is responsible for capturing raw sensory data from the environment,
    processing this data into meaningful observations, and providing access to these observations
    for other components of the cognitive architecture.

    Attributes:
        config (Any): Configuration settings for the Perception System.
        state (Any): The internal state of the Perception System, including raw data and observations.
    """

    def __init__(self, config: Any):
        """
        Initialize the Perception System with the provided configuration.

        Args:
            config (Any): Configuration settings for the Perception System.
        """
        self.config = config
        self.state = self.init_state()

    @abstractmethod
    def init_state(self) -> Any:
        """
        Initialize the internal state of the Perception System.

        This method should set up the initial state required for the Perception System's operations.

        Returns:
            Any: The initial state of the Perception System.
        """
        pass

    @abstractmethod
    async def setup(self) -> None:
        """
        Asynchronously set up the Perception System's components.

        This method should initialize any necessary components or resources based on the provided configuration.

        Raises:
            ConfigurationError: If the configuration is invalid or incomplete.
        """
        pass

    @abstractmethod
    async def capture(self, raw_data: Any) -> None:
        """
        Asynchronously capture raw sensory data from the environment.

        Args:
            raw_data (Any): The raw sensory data to capture.

        Raises:
            CaptureError: If capturing the raw data fails.
        """
        pass

    @abstractmethod
    async def process(self) -> None:
        """
        Asynchronously process the captured raw sensory data into meaningful observations.

        This method should transform raw data stored in the internal state into structured observations
        that can be utilized by other components of the cognitive architecture.

        Raises:
            ProcessingError: If processing the raw data fails.
        """
        pass

    async def perceive(self, raw_data: Any) -> None:
        """
        Asynchronously perform the full perception cycle: capture and process raw sensory data.

        Args:
            raw_data (Any): The raw sensory data to perceive.

        Raises:
            CaptureError: If capturing the raw data fails.
            ProcessingError: If processing the raw data fails.
        """
        try:
            await self.capture(raw_data)
            await self.process()
        except Exception as e:
            self.handle_error(e)
            raise e

    def get_observations(self) -> Any:
        """
        Retrieve the current processed observations from the Perception System.

        Returns:
            Any: The current observations.
        """
        return self.state.get("observations", None)

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during perception operations.

        This method can be overridden to implement custom error handling logic, such as logging
        or retry mechanisms.

        Args:
            error (Exception): The exception that was raised.
        """
        # Default error handling: log the error
        print(f"PerceptionSystem encountered an error: {error}")
__init__(config)

Initialize the Perception System with the provided configuration.

Parameters:

Name Type Description Default
config Any

Configuration settings for the Perception System.

required
Source code in src/aeiva/perception/base_perception_system.py
def __init__(self, config: Any):
    """
    Initialize the Perception System with the provided configuration.

    Args:
        config (Any): Configuration settings for the Perception System.
    """
    self.config = config
    self.state = self.init_state()
capture(raw_data) abstractmethod async

Asynchronously capture raw sensory data from the environment.

Parameters:

Name Type Description Default
raw_data Any

The raw sensory data to capture.

required

Raises:

Type Description
CaptureError

If capturing the raw data fails.

Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod
async def capture(self, raw_data: Any) -> None:
    """
    Asynchronously capture raw sensory data from the environment.

    Args:
        raw_data (Any): The raw sensory data to capture.

    Raises:
        CaptureError: If capturing the raw data fails.
    """
    pass
get_observations()

Retrieve the current processed observations from the Perception System.

Returns:

Name Type Description
Any Any

The current observations.

Source code in src/aeiva/perception/base_perception_system.py
def get_observations(self) -> Any:
    """
    Retrieve the current processed observations from the Perception System.

    Returns:
        Any: The current observations.
    """
    return self.state.get("observations", None)
handle_error(error)

Handle errors that occur during perception operations.

This method can be overridden to implement custom error handling logic, such as logging or retry mechanisms.

Parameters:

Name Type Description Default
error Exception

The exception that was raised.

required
Source code in src/aeiva/perception/base_perception_system.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during perception operations.

    This method can be overridden to implement custom error handling logic, such as logging
    or retry mechanisms.

    Args:
        error (Exception): The exception that was raised.
    """
    # Default error handling: log the error
    print(f"PerceptionSystem encountered an error: {error}")
init_state() abstractmethod

Initialize the internal state of the Perception System.

This method should set up the initial state required for the Perception System's operations.

Returns:

Name Type Description
Any Any

The initial state of the Perception System.

Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod
def init_state(self) -> Any:
    """
    Initialize the internal state of the Perception System.

    This method should set up the initial state required for the Perception System's operations.

    Returns:
        Any: The initial state of the Perception System.
    """
    pass
perceive(raw_data) async

Asynchronously perform the full perception cycle: capture and process raw sensory data.

Parameters:

Name Type Description Default
raw_data Any

The raw sensory data to perceive.

required

Raises:

Type Description
CaptureError

If capturing the raw data fails.

ProcessingError

If processing the raw data fails.

Source code in src/aeiva/perception/base_perception_system.py
async def perceive(self, raw_data: Any) -> None:
    """
    Asynchronously perform the full perception cycle: capture and process raw sensory data.

    Args:
        raw_data (Any): The raw sensory data to perceive.

    Raises:
        CaptureError: If capturing the raw data fails.
        ProcessingError: If processing the raw data fails.
    """
    try:
        await self.capture(raw_data)
        await self.process()
    except Exception as e:
        self.handle_error(e)
        raise e
process() abstractmethod async

Asynchronously process the captured raw sensory data into meaningful observations.

This method should transform raw data stored in the internal state into structured observations that can be utilized by other components of the cognitive architecture.

Raises:

Type Description
ProcessingError

If processing the raw data fails.

Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod
async def process(self) -> None:
    """
    Asynchronously process the captured raw sensory data into meaningful observations.

    This method should transform raw data stored in the internal state into structured observations
    that can be utilized by other components of the cognitive architecture.

    Raises:
        ProcessingError: If processing the raw data fails.
    """
    pass
setup() abstractmethod async

Asynchronously set up the Perception System's components.

This method should initialize any necessary components or resources based on the provided configuration.

Raises:

Type Description
ConfigurationError

If the configuration is invalid or incomplete.

Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod
async def setup(self) -> None:
    """
    Asynchronously set up the Perception System's components.

    This method should initialize any necessary components or resources based on the provided configuration.

    Raises:
        ConfigurationError: If the configuration is invalid or incomplete.
    """
    pass
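
A minimal sketch (not from the Aeiva source) of a hypothetical concrete subclass, showing how init_state, setup, capture, and process fit together in the perceive cycle.

from typing import Any
from aeiva.perception.base_perception_system import PerceptionSystem

class TextPerceptionSystem(PerceptionSystem):  # hypothetical subclass for illustration
    def init_state(self) -> Any:
        return {"raw_data": None, "observations": None}

    async def setup(self) -> None:
        pass  # no resources needed in this toy example

    async def capture(self, raw_data: Any) -> None:
        self.state["raw_data"] = raw_data

    async def process(self) -> None:
        # Turn the raw text into a normalized "observation".
        self.state["observations"] = str(self.state["raw_data"]).strip().lower()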

perception_system

PerceptionSystem

Manages multiple sensors and emits stimuli via the EventBus.

Source code in src/aeiva/perception/perception_system.py
class PerceptionSystem:
    """
    Manages multiple sensors and emits stimuli via the EventBus.
    """
    def __init__(self, config: Dict, event_bus):
        """
        Initializes the PerceptionSystem with a list of sensors.

        Args:
            config (Any): Configuration dictionary for the sensors.
            event_bus: The EventBus instance for emitting events.
        """
        self.config = config
        self.event_bus = event_bus
        self.sensors: List[Sensor] = []
        self.logger = logging.getLogger('PerceptionSystem')

    def setup(self) -> None:
        """
        Sets up the perception system by initializing all configured sensors.
        """
        for sensor_config in self.config.get("sensors", []):
            sensor_name = sensor_config.get("sensor_name")
            sensor_params = sensor_config.get("sensor_params", {})
            # TODO: revise later
            if sensor_name == 'percept_terminal_input':
                sensor = TerminalInputSensor(sensor_name, sensor_params, self.event_bus)
                self.sensors.append(sensor)
            else:
                self.logger.warning(f"Unknown sensor type: {sensor_name}")
        self.logger.info("PerceptionSystem setup complete.")

    async def start(self) -> None:  # TODO: maybe rename in the future
        """
        Starts all sensors asynchronously.
        """
        self.logger.info("Starting all sensors.")
        for sensor in self.sensors:
            await sensor.start()

    async def stop(self) -> None:
        """
        Stops all sensors asynchronously.
        """
        self.logger.info("Stopping all sensors.")
        for sensor in self.sensors:
            await sensor.stop()

    def signal_to_stimuli(self, data: Any) -> Any:
        """
        Processes raw data from sensors into structured stimuli.

        Args:
            data: The raw data emitted by sensors.

        Returns:
            Processed data (stimuli).
        """
        # Implement your data processing logic here
        signal = Signal(
            data=data,
            modularity="text",  # Or appropriate modality
            type="input",       # Or appropriate type
            # TODO: After revised Sensor class, Include other metadata as needed
        )
        stimuli = Stimuli(signals=[signal])  # TODO: add more fields
        return stimuli
__init__(config, event_bus)

Initializes the PerceptionSystem with a list of sensors.

Parameters:

Name Type Description Default
config Any

Configuration dictionary for the sensors.

required
event_bus

The EventBus instance for emitting events.

required
Source code in src/aeiva/perception/perception_system.py
def __init__(self, config: Dict, event_bus):
    """
    Initializes the PerceptionSystem with a list of sensors.

    Args:
        config (Any): Configuration dictionary for the sensors.
        event_bus: The EventBus instance for emitting events.
    """
    self.config = config
    self.event_bus = event_bus
    self.sensors: List[Sensor] = []
    self.logger = logging.getLogger('PerceptionSystem')
setup()

Sets up the perception system by initializing all configured sensors.

Source code in src/aeiva/perception/perception_system.py
def setup(self) -> None:
    """
    Sets up the perception system by initializing all configured sensors.
    """
    for sensor_config in self.config.get("sensors", []):
        sensor_name = sensor_config.get("sensor_name")
        sensor_params = sensor_config.get("sensor_params", {})
        # TODO: revise later
        if sensor_name == 'percept_terminal_input':
            sensor = TerminalInputSensor(sensor_name, sensor_params, self.event_bus)
            self.sensors.append(sensor)
        else:
            self.logger.warning(f"Unknown sensor type: {sensor_name}")
    self.logger.info("PerceptionSystem setup complete.")
signal_to_stimuli(data)

Processes raw data from sensors into structured stimuli.

Parameters:

Name Type Description Default
data Any

The raw data emitted by sensors.

required

Returns:

Type Description
Any

Processed data (stimuli).

Source code in src/aeiva/perception/perception_system.py
def signal_to_stimuli(self, data: Any) -> Any:
    """
    Processes raw data from sensors into structured stimuli.

    Args:
        data: The raw data emitted by sensors.

    Returns:
        Processed data (stimuli).
    """
    # Implement your data processing logic here
    signal = Signal(
        data=data,
        modularity="text",  # Or appropriate modality
        type="input",       # Or appropriate type
        # TODO: After revised Sensor class, Include other metadata as needed
    )
    stimuli = Stimuli(signals=[signal])  # TODO: add more fields
    return stimuli
start() async

Starts all sensors asynchronously.

Source code in src/aeiva/perception/perception_system.py
async def start(self) -> None:  # TODO: maybe rename in the future
    """
    Starts all sensors asynchronously.
    """
    self.logger.info("Starting all sensors.")
    for sensor in self.sensors:
        await sensor.start()
stop() async

Stops all sensors asynchronously.

Source code in src/aeiva/perception/perception_system.py
async def stop(self) -> None:
    """
    Stops all sensors asynchronously.
    """
    self.logger.info("Stopping all sensors.")
    for sensor in self.sensors:
        await sensor.stop()
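
A minimal sketch (not from the Aeiva source) wiring the system together; construction of the EventBus instance is omitted because its module path is not shown here.

import asyncio  # to drive the coroutine, e.g. asyncio.run(main(bus))
from aeiva.perception.perception_system import PerceptionSystem

async def main(event_bus):  # event_bus: the project's EventBus instance
    config = {"sensors": [{"sensor_name": "percept_terminal_input", "sensor_params": {}}]}
    perception = PerceptionSystem(config, event_bus)
    perception.setup()        # instantiates the terminal input sensor
    await perception.start()  # sensors begin emitting events on the bus
    ...
    await perception.stop()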

sensation

Signal

Represents an atomic unit of perception that carries raw data from the environment. This class defines a signal, its characteristics, and its dependencies on other signals.

Source code in src/aeiva/perception/sensation.py
class Signal:
    """
    Represents an atomic unit of perception that carries raw data from the environment.
    This class defines a signal, its characteristics, and its dependencies on other signals.
    """

    def __init__(self, 
                 data: Any,
                 name: Optional[str] = None,  # Optional name for the signal
                 modularity: Optional[str] = None,
                 type: Optional[str] = None,  # Renamed to avoid keyword conflict
                 timestamp: Optional[datetime] = None,
                 id: Optional[str] = None,  # Optional unique identifier for the signal
                 dependencies: Optional[Dict[str, Any]] = None,  # Dependencies by other signal IDs with edge attributes
                 description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        """
        Initialize a signal with its data and other optional metadata.

        Args:
            data (Any): The raw data of the signal.
            name (Optional[str]): An optional name for the signal.
            modularity (Optional[str]): The modality of the signal (e.g., image, video, text, audio).
            type (Optional[str]): A more detailed signal type (e.g., 'text', 'document', etc.).
            timestamp (Optional[datetime]): The time when the signal was created or captured.
            id (Optional[str]): Unique identifier for the signal.
            dependencies (Optional[Dict[str, Any]]): Attributes of dependencies (e.g., relationship types).
            description (Optional[str]): Description of the signal.
            metadata (Optional[Dict[str, Any]]): Optional additional metadata for the signal.
        """
        self.data = data
        self.name = name
        self.modularity = modularity
        self.type = type
        self.timestamp = timestamp or datetime.now()
        self.id = id
        self.dependencies = dependencies or {}  # Edge attributes (could be string, embedding, etc.)
        self.description = description
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the signal into a dictionary representation.
        """
        return {
            "data": self.data,
            "name": self.name,
            "modularity": self.modularity,
            "type": self.type,
            "timestamp": self.timestamp,
            "id": self.id,
            "dependencies": self.dependencies,
            "description": self.description,
            "metadata": self.metadata
        }
__init__(data, name=None, modularity=None, type=None, timestamp=None, id=None, dependencies=None, description=None, metadata=None)

Initialize a signal with its data and other optional metadata.

Parameters:

Name Type Description Default
data Any

The raw data of the signal.

required
name Optional[str]

An optional name for the signal.

None
modularity Optional[str]

The modality of the signal (e.g., image, video, text, audio).

None
type Optional[str]

A more detailed signal type (e.g., 'text', 'document', etc.).

None
timestamp Optional[datetime]

The time when the signal was created or captured.

None
id Optional[str]

Unique identifier for the signal.

None
dependencies Optional[Dict[str, Any]]

Attributes of dependencies (e.g., relationship types).

None
description Optional[str]

Description of the signal.

None
metadata Optional[Dict[str, Any]]

Optional additional metadata for the signal.

None
Source code in src/aeiva/perception/sensation.py
def __init__(self, 
             data: Any,
             name: Optional[str] = None,  # Optional name for the signal
             modularity: Optional[str] = None,
             type: Optional[str] = None,  # Renamed to avoid keyword conflict
             timestamp: Optional[datetime] = None,
             id: Optional[str] = None,  # Optional unique identifier for the signal
             dependencies: Optional[Dict[str, Any]] = None,  # Dependencies by other signal IDs with edge attributes
             description: Optional[str] = None,
             metadata: Optional[Dict[str, Any]] = None):
    """
    Initialize a signal with its data and other optional metadata.

    Args:
        data (Any): The raw data of the signal.
        name (Optional[str]): An optional name for the signal.
        modularity (Optional[str]): The modality of the signal (e.g., image, video, text, audio).
        type (Optional[str]): A more detailed signal type (e.g., 'text', 'document', etc.).
        timestamp (Optional[datetime]): The time when the signal was created or captured.
        id (Optional[str]): Unique identifier for the signal.
        dependencies (Optional[Dict[str, Any]]): Attributes of dependencies (e.g., relationship types).
        description (Optional[str]): Description of the signal.
        metadata (Optional[Dict[str, Any]]): Optional additional metadata for the signal.
    """
    self.data = data
    self.name = name
    self.modularity = modularity
    self.type = type
    self.timestamp = timestamp or datetime.now()
    self.id = id
    self.dependencies = dependencies or {}  # Edge attributes (could be string, embedding, etc.)
    self.description = description
    self.metadata = metadata or {}
to_dict()

Converts the signal into a dictionary representation.

Source code in src/aeiva/perception/sensation.py
def to_dict(self) -> Dict[str, Any]:
    """
    Converts the signal into a dictionary representation.
    """
    return {
        "data": self.data,
        "name": self.name,
        "modularity": self.modularity,
        "type": self.type,
        "timestamp": self.timestamp,
        "id": self.id,
        "dependencies": self.dependencies,
        "description": self.description,
        "metadata": self.metadata
    }
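
A minimal sketch constructing a Signal and serializing it, based on the constructor documented above.

from aeiva.perception.sensation import Signal

sig = Signal(data="hello world", name="greeting", modularity="text", type="input")
print(sig.to_dict()["modularity"])  # 'text'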

sensor

Sensor

Bases: ABC

Abstract base class for all sensors.

Source code in src/aeiva/perception/sensor.py
class Sensor(ABC):
    """
    Abstract base class for all sensors.
    """
    def __init__(self, name: str, params: dict, event_bus):
        """
        Initializes the BaseSensor.

        Args:
            name (str): The name of the sensor.
            params (dict): Configuration parameters for the sensor.
            event_bus: The EventBus instance for emitting events.
        """
        self.name = name
        self.params = params
        self.event_bus = event_bus

    @abstractmethod
    async def start(self):
        """
        Starts the sensor.
        """
        pass

    @abstractmethod
    async def stop(self):
        """
        Stops the sensor.
        """
        pass
__init__(name, params, event_bus)

Initializes the BaseSensor.

Parameters:

Name Type Description Default
name str

The name of the sensor.

required
params dict

Configuration parameters for the sensor.

required
event_bus

The EventBus instance for emitting events.

required
Source code in src/aeiva/perception/sensor.py
def __init__(self, name: str, params: dict, event_bus):
    """
    Initializes the BaseSensor.

    Args:
        name (str): The name of the sensor.
        params (dict): Configuration parameters for the sensor.
        event_bus: The EventBus instance for emitting events.
    """
    self.name = name
    self.params = params
    self.event_bus = event_bus
start() abstractmethod async

Starts the sensor.

Source code in src/aeiva/perception/sensor.py
@abstractmethod
async def start(self):
    """
    Starts the sensor.
    """
    pass
stop() abstractmethod async

Stops the sensor.

Source code in src/aeiva/perception/sensor.py
@abstractmethod
async def stop(self):
    """
    Stops the sensor.
    """
    pass
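
A minimal sketch (not from the Aeiva source) of a hypothetical sensor; the event_bus.emit call is an assumed API, so adapt it to the project's actual EventBus interface.

import asyncio
from aeiva.perception.sensor import Sensor

class HeartbeatSensor(Sensor):  # hypothetical sensor for illustration
    """Emits a heartbeat event on the bus at a fixed interval."""

    async def start(self):
        self._task = asyncio.create_task(self._run())

    async def _run(self):
        while True:
            await asyncio.sleep(self.params.get("interval", 1.0))
            await self.event_bus.emit("perception.heartbeat", payload=self.name)  # assumed EventBus API

    async def stop(self):
        self._task.cancel()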

stimuli

Stimuli

Represents a structured composition of signals, where each node can be a Signal or a sub-Stimuli. The graph allows flexible, directed relationships between nodes, and the graph can contain cycles.

Source code in src/aeiva/perception/stimuli.py
class Stimuli:
    """
    Represents a structured composition of signals, where each node can be a Signal or a sub-Stimuli.
    The graph allows flexible, directed relationships between nodes, and the graph can contain cycles.
    """

    def __init__(self, 
                 signals: List[Union[Signal, 'Stimuli']],
                 id: Optional[str] = None,
                 name: Optional[str] = None,
                 type: Optional[str] = None,
                 modularity: Optional[str] = None,
                 timestamp: Optional[str] = None,
                 dependencies: Optional[Dict[str, Dict[str, Any]]] = None,
                 description: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        """
        Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.
        """
        self.signals = signals or []  # Default to an empty list if no signals provided
        self.id = id
        self.name = name
        self.type = type
        self.modularity = modularity
        self.timestamp = timestamp
        self.description = description
        self.metadata = metadata or {}
        self.dependencies = dependencies or {}

        # Graph to represent the structure of signals and their relationships
        self.graph = nx.DiGraph()

        # Add all signals and sub-stimuli as nodes in the graph
        for signal in signals:
            self.graph.add_node(signal)

        # Handle dependencies for signals or sub-stimuli
        for signal in signals:
            if signal.id in self.dependencies:
                for dep_id, edge_attr in self.dependencies[signal.id].items():
                    dep_node = next((s for s in signals if s.id == dep_id), None)
                    if dep_node and (isinstance(dep_node, Signal) or isinstance(dep_node, Stimuli)):
                        self.graph.add_edge(dep_node, signal, **edge_attr)
                    else:
                        raise ValueError(f"Dependency {dep_id} not found or is not valid for signal or stimuli {signal.id}.")

    def traverse(self, method: str = 'dfs') -> List[Union[Signal, 'Stimuli']]:
        """
        Traverses the graph using the specified method ('dfs' or 'bfs').

        Args:
            method (str): The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First Search).

        Returns:
            List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.
        """
        if not self.graph.nodes:
            return []

        if method == 'dfs':
            return list(nx.dfs_postorder_nodes(self.graph))
        elif method == 'bfs':
            return list(nx.bfs_tree(self.graph, list(self.graph.nodes)[0]))  # BFS starting from an arbitrary node
        else:
            raise ValueError(f"Unknown traversal method: {method}")

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the stimuli into a dictionary representation, including its signals and their relationships.
        """
        return {
            "id": self.id,
            "name": self.name,
            "type": self.type,
            "modularity": self.modularity,
            "timestamp": self.timestamp,
            "description": self.description,
            "metadata": self.metadata,
            "signals": [signal.to_dict() if isinstance(signal, Signal) else signal.to_dict() for signal in self.signals],
            "dependencies": self.dependencies
        }

    def visualize(self, save_path: Optional[str] = None):
        """
        Visualizes the stimuli's structure using networkx and matplotlib.
        """
        pos = nx.spring_layout(self.graph)  # Layout for the graph
        labels = {node: f"{node.id} ({node.type})" if isinstance(node, Signal) else f"{node.id} (Stimuli)"
                  for node in self.graph.nodes()}

        # Draw the graph with labels
        nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)

        plt.title(f"{self.type} {self.description} Visualization")
        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()
__init__(signals, id=None, name=None, type=None, modularity=None, timestamp=None, dependencies=None, description=None, metadata=None)

Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.

Source code in src/aeiva/perception/stimuli.py
def __init__(self, 
             signals: List[Union[Signal, 'Stimuli']],
             id: Optional[str] = None,
             name: Optional[str] = None,
             type: Optional[str] = None,
             modularity: Optional[str] = None,
             timestamp: Optional[str] = None,
             dependencies: Optional[Dict[str, Dict[str, Any]]] = None,
             description: Optional[str] = None,
             metadata: Optional[Dict[str, Any]] = None):
    """
    Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.
    """
    self.signals = signals or []  # Default to an empty list if no signals provided
    self.id = id
    self.name = name
    self.type = type
    self.modularity = modularity
    self.timestamp = timestamp
    self.description = description
    self.metadata = metadata or {}
    self.dependencies = dependencies or {}

    # Graph to represent the structure of signals and their relationships
    self.graph = nx.DiGraph()

    # Add all signals and sub-stimuli as nodes in the graph
    for signal in signals:
        self.graph.add_node(signal)

    # Handle dependencies for signals or sub-stimuli
    for signal in signals:
        if signal.id in self.dependencies:
            for dep_id, edge_attr in self.dependencies[signal.id].items():
                dep_node = next((s for s in signals if s.id == dep_id), None)
                if dep_node and (isinstance(dep_node, Signal) or isinstance(dep_node, Stimuli)):
                    self.graph.add_edge(dep_node, signal, **edge_attr)
                else:
                    raise ValueError(f"Dependency {dep_id} not found or is not valid for signal or stimuli {signal.id}.")
to_dict()

Converts the stimuli into a dictionary representation, including its signals and their relationships.

Source code in src/aeiva/perception/stimuli.py
def to_dict(self) -> Dict[str, Any]:
    """
    Converts the stimuli into a dictionary representation, including its signals and their relationships.
    """
    return {
        "id": self.id,
        "name": self.name,
        "type": self.type,
        "modularity": self.modularity,
        "timestamp": self.timestamp,
        "description": self.description,
        "metadata": self.metadata,
        "signals": [signal.to_dict() if isinstance(signal, Signal) else signal.to_dict() for signal in self.signals],
        "dependencies": self.dependencies
    }
traverse(method='dfs')

Traverses the graph using the specified method ('dfs' or 'bfs').

Parameters:

Name Type Description Default
method str

The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First Search).

'dfs'

Returns:

Type Description
List[Union[Signal, Stimuli]]

List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.

Source code in src/aeiva/perception/stimuli.py
def traverse(self, method: str = 'dfs') -> List[Union[Signal, 'Stimuli']]:
    """
    Traverses the graph using the specified method ('dfs' or 'bfs').

    Args:
        method (str): The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First Search).

    Returns:
        List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.
    """
    if not self.graph.nodes:
        return []

    if method == 'dfs':
        return list(nx.dfs_postorder_nodes(self.graph))
    elif method == 'bfs':
        return list(nx.bfs_tree(self.graph, list(self.graph.nodes)[0]))  # BFS starting from an arbitrary node
    else:
        raise ValueError(f"Unknown traversal method: {method}")
visualize(save_path=None)

Visualizes the stimuli's structure using networkx and matplotlib.

Source code in src/aeiva/perception/stimuli.py
def visualize(self, save_path: Optional[str] = None):
    """
    Visualizes the stimuli's structure using networkx and matplotlib.
    """
    pos = nx.spring_layout(self.graph)  # Layout for the graph
    labels = {node: f"{node.id} ({node.type})" if isinstance(node, Signal) else f"{node.id} (Stimuli)"
              for node in self.graph.nodes()}

    # Draw the graph with labels
    nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)

    plt.title(f"{self.type} {self.description} Visualization")
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()
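
A short construction-and-traversal sketch. It assumes Signal accepts data, id, and type as keywords; note that dependencies maps a node's id to the ids it depends on, and each dependency becomes a directed edge from the dependency to the dependent node.

from aeiva.perception.sensation import Signal
from aeiva.perception.stimuli import Stimuli  # module path taken from the listing above

s1 = Signal(data="frame-1", id="s1", type="video")
s2 = Signal(data="a caption", id="s2", type="text")

# s2 depends on s1, yielding the edge s1 -> s2 with an illustrative attribute.
stimuli = Stimuli(signals=[s1, s2], dependencies={"s2": {"s1": {"relation": "describes"}}})

for node in stimuli.traverse(method='bfs'):
    print(node.id)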

terminal_input_sensor

TerminalInputSensor

Bases: Sensor

A sensor that reads input from the terminal and emits stimuli via the EventBus.

Source code in src/aeiva/perception/terminal_input_sensor.py
class TerminalInputSensor(Sensor):
    """
    A sensor that reads input from the terminal and emits stimuli via the EventBus.
    """
    def __init__(self, name: str, params: dict, event_bus):
        super().__init__(name, params, event_bus)
        self.prompt_message = params.get('prompt_message', 'You: ')
        self._running = False
        self._thread = None
        # self.logger = logging.getLogger(f'TerminalInputSensor-{self.name}')

    async def start(self):
        """
        Starts the sensor by launching the input thread.
        """
        self._running = True
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        # self.logger.info(f"{self.name} started.")

    async def stop(self):
        """
        Stops the sensor by signaling the thread to stop and waiting for it to finish.
        """
        self._running = False
        if self._thread:
            self._thread.join()
            # self.logger.info(f"{self.name} stopped.")

    def _run(self):
        """
        The main loop that reads user input and emits events.
        """
        loop = self.event_bus.loop
        if loop is None:
            # self.logger.error("EventBus loop is not set. Cannot emit events.")
            return

        while self._running:
            try:
                user_input = input(self.prompt_message)
                if not self._running:
                    break  # Exit if stopped during input

                # # Process input into stimuli
                # stimuli = self.signal_to_stimuli(user_input)

                # Emit the stimuli as an event
                asyncio.run_coroutine_threadsafe(
                    self.event_bus.emit('perception.stimuli', payload=user_input),  # TODO: rename event later
                    loop
                )
            except EOFError:
                # Handle end of input (Ctrl+D)
                # self.logger.info("EOF received. Stopping TerminalInputSensor.")
                self._running = False
            except KeyboardInterrupt:
                # Handle Ctrl+C
                # self.logger.info("KeyboardInterrupt received. Stopping TerminalInputSensor.")
                self._running = False
            except Exception as e:
                # self.logger.error(f"Error in TerminalInputSensor: {e}")
                self._running = False
start() async

Starts the sensor by launching the input thread.

Source code in src/aeiva/perception/terminal_input_sensor.py
async def start(self):
    """
    Starts the sensor by launching the input thread.
    """
    self._running = True
    self._thread = threading.Thread(target=self._run, daemon=True)
    self._thread.start()
stop() async

Stops the sensor by signaling the thread to stop and waiting for it to finish.

Source code in src/aeiva/perception/terminal_input_sensor.py
async def stop(self):
    """
    Stops the sensor by signaling the thread to stop and waiting for it to finish.
    """
    self._running = False
    if self._thread:
        self._thread.join()
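
A rough wiring sketch. The real EventBus lives elsewhere in Aeiva; the MiniBus stand-in below is hypothetical and mimics only the two members this sensor touches (a loop attribute and an awaitable emit). Note that stop() joins the input thread, so it returns only after the thread leaves its blocking input() call.

import asyncio

class MiniBus:
    """Hypothetical stand-in exposing only what TerminalInputSensor uses."""
    def __init__(self, loop):
        self.loop = loop

    async def emit(self, event_name, payload=None):
        print(f"[{event_name}] {payload}")

async def main():
    bus = MiniBus(asyncio.get_running_loop())
    sensor = TerminalInputSensor("terminal", {"prompt_message": "You: "}, bus)
    await sensor.start()     # spawns the blocking input() loop in a daemon thread
    await asyncio.sleep(30)  # give the user some time to type
    await sensor.stop()

asyncio.run(main())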

test

handle_observation(stimuli) async

Processes stimuli using the cognition system and outputs the response.

Source code in src/aeiva/perception/test.py
async def handle_observation(stimuli):
    """
    Processes stimuli using the cognition system and outputs the response.
    """
    for signal in stimuli.signals:
        user_input = signal.data
        stimuli_data = [{"role": "user", "content": user_input}]
        response = await llm_brain.think(stimuli_data, stream=True)
        print(f"LLM Response: {response}")

plugin

ability

plugin_a

plugin
PluginA

Bases: Plugin

Example Plugin A.

Source code in src/aeiva/plugin/ability/plugin_a/plugin.py
class PluginA(Plugin):
    """
    Example Plugin A.
    """

    def activate(self) -> None:
        print("PluginA activated.")

    def deactivate(self) -> None:
        print("PluginA deactivated.")

    def run(self) -> None:
        print("PluginA is running.")

plugin_b

plugin
PluginB

Bases: Plugin

Example Plugin B.

Source code in src/aeiva/plugin/ability/plugin_b/plugin.py
class PluginB(Plugin):
    """
    Example Plugin B.
    """

    def activate(self) -> None:
        print("PluginB activated.")

    def deactivate(self) -> None:
        print("PluginB deactivated.")

    def run(self) -> None:
        print("PluginB is running.")

plug

Plug Module

This module provides a flexible plugin system with support for:

  • Multiple plugin sources with isolation
  • Context managers and import hooks
  • Resource loading from plugins
  • Loading plugins from directories and zip files
  • Hot swapping and lazy loading of plugins

Author: Bang Liu
Date: 2024-11-19
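
Before the individual classes, a small end-to-end sketch of how the pieces fit together. The source name and the search path are illustrative, and the example leans on the PluginA class documented earlier.

from aeiva.plugin.plug import PluginManager  # module path taken from the listings below

manager = PluginManager()
source = manager.create_plugin_source('ability', search_path=['src/aeiva/plugin/ability'])

print(source.list_plugins())       # e.g. ['plugin_a', 'plugin_b']
module = source.load_plugin('plugin_a')
plugin = module.PluginA()          # plugin.py defines the Plugin subclass
plugin.activate()
plugin.run()
plugin.deactivate()
source.unload_plugin('plugin_a')
manager.remove_plugin_source('ability')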

Plugin

Bases: ABC

Abstract base class that all plugins must inherit from.

Source code in src/aeiva/plugin/plug.py
class Plugin(abc.ABC):
    """
    Abstract base class that all plugins must inherit from.
    """

    @abc.abstractmethod
    def activate(self) -> None:
        """Method called when the plugin is activated."""
        pass

    @abc.abstractmethod
    def deactivate(self) -> None:
        """Method called when the plugin is deactivated."""
        pass
activate() abstractmethod

Method called when the plugin is activated.

Source code in src/aeiva/plugin/plug.py
@abc.abstractmethod
def activate(self) -> None:
    """Method called when the plugin is activated."""
    pass
deactivate() abstractmethod

Method called when the plugin is deactivated.

Source code in src/aeiva/plugin/plug.py
@abc.abstractmethod
def deactivate(self) -> None:
    """Method called when the plugin is deactivated."""
    pass

PluginFinder

Bases: MetaPathFinder

Custom finder for plugin modules. Finds plugins as directories containing a plugin.py file.

Source code in src/aeiva/plugin/plug.py
class PluginFinder(importlib.abc.MetaPathFinder):
    """
    Custom finder for plugin modules.
    Finds plugins as directories containing a `plugin.py` file.
    """

    def __init__(self, plugin_source: 'PluginSource') -> None:
        self.plugin_source = plugin_source

    def find_spec(
        self,
        fullname: str,
        path: Optional[List[str]],
        target: Optional[ModuleType] = None
    ) -> Optional[importlib.machinery.ModuleSpec]:
        """
        Find the module spec for the given module.
        Handles both the namespace package and its submodules (plugins).
        """
        if fullname == self.plugin_source.namespace:
            # Handle the namespace package itself
            print(f"PluginFinder: Creating namespace package '{fullname}'")
            spec = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)
            spec.submodule_search_locations = []
            return spec

        elif fullname.startswith(self.plugin_source.namespace + '.'):
            # Handle submodules (plugins)
            plugin_name = fullname[len(self.plugin_source.namespace) + 1:]
            if plugin_name in self.plugin_source.list_plugins():
                print(f"PluginFinder: Found plugin '{plugin_name}' for module '{fullname}'")
                loader = PluginLoader(self.plugin_source, plugin_name)
                spec = importlib.util.spec_from_loader(fullname, loader)
                spec.submodule_search_locations = []
                return spec

        # If not handling this module, return None
        print(f"PluginFinder: Not handling module '{fullname}'")
        return None
find_spec(fullname, path, target=None)

Find the module spec for the given module. Handles both the namespace package and its submodules (plugins).

Source code in src/aeiva/plugin/plug.py
def find_spec(
    self,
    fullname: str,
    path: Optional[List[str]],
    target: Optional[ModuleType] = None
) -> Optional[importlib.machinery.ModuleSpec]:
    """
    Find the module spec for the given module.
    Handles both the namespace package and its submodules (plugins).
    """
    if fullname == self.plugin_source.namespace:
        # Handle the namespace package itself
        print(f"PluginFinder: Creating namespace package '{fullname}'")
        spec = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)
        spec.submodule_search_locations = []
        return spec

    elif fullname.startswith(self.plugin_source.namespace + '.'):
        # Handle submodules (plugins)
        plugin_name = fullname[len(self.plugin_source.namespace) + 1:]
        if plugin_name in self.plugin_source.list_plugins():
            print(f"PluginFinder: Found plugin '{plugin_name}' for module '{fullname}'")
            loader = PluginLoader(self.plugin_source, plugin_name)
            spec = importlib.util.spec_from_loader(fullname, loader)
            spec.submodule_search_locations = []
            return spec

    # If not handling this module, return None
    print(f"PluginFinder: Not handling module '{fullname}'")
    return None

PluginLoader

Bases: Loader

Custom loader for plugin modules. Loads the plugin.py file within the plugin directory.

Source code in src/aeiva/plugin/plug.py
class PluginLoader(importlib.abc.Loader):
    """
    Custom loader for plugin modules.
    Loads the `plugin.py` file within the plugin directory.
    """

    def __init__(self, plugin_source: 'PluginSource', plugin_name: str) -> None:
        self.plugin_source = plugin_source
        self.plugin_name = plugin_name

    def create_module(self, spec: importlib.machinery.ModuleSpec) -> Optional[ModuleType]:
        """Use default module creation semantics."""
        return None

    def exec_module(self, module: ModuleType) -> None:
        """Execute the plugin's `plugin.py` module."""
        try:
            code = self.plugin_source.get_plugin_code(self.plugin_name)
        except ImportError as e:
            print(f"PluginLoader: Failed to get code for plugin '{self.plugin_name}': {e}")
            raise

        # Compute project_root dynamically based on plug.py's location
        plugin_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.abspath(os.path.join(plugin_dir, '../../../'))
        print(f"PluginLoader: Adding '{project_root}' to sys.path for plugin '{self.plugin_name}'")
        sys.path.insert(0, project_root)

        try:
            print(f"PluginLoader: Executing plugin '{self.plugin_name}'")
            exec(code, module.__dict__)
            print(f"PluginLoader: Plugin '{self.plugin_name}' executed successfully")
        except Exception as e:
            print(f"PluginLoader: Error executing plugin '{self.plugin_name}': {e}")
            raise
        finally:
            sys.path.pop(0)
create_module(spec)

Use default module creation semantics.

Source code in src/aeiva/plugin/plug.py
def create_module(self, spec: importlib.machinery.ModuleSpec) -> Optional[ModuleType]:
    """Use default module creation semantics."""
    return None
exec_module(module)

Execute the plugin's plugin.py module.

Source code in src/aeiva/plugin/plug.py
def exec_module(self, module: ModuleType) -> None:
    """Execute the plugin's `plugin.py` module."""
    try:
        code = self.plugin_source.get_plugin_code(self.plugin_name)
    except ImportError as e:
        print(f"PluginLoader: Failed to get code for plugin '{self.plugin_name}': {e}")
        raise

    # Compute project_root dynamically based on plug.py's location
    plugin_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.abspath(os.path.join(plugin_dir, '../../../'))
    print(f"PluginLoader: Adding '{project_root}' to sys.path for plugin '{self.plugin_name}'")
    sys.path.insert(0, project_root)

    try:
        print(f"PluginLoader: Executing plugin '{self.plugin_name}'")
        exec(code, module.__dict__)
        print(f"PluginLoader: Plugin '{self.plugin_name}' executed successfully")
    except Exception as e:
        print(f"PluginLoader: Error executing plugin '{self.plugin_name}': {e}")
        raise
    finally:
        sys.path.pop(0)

PluginManager

Manages multiple PluginSources and controls plugin imports.

Source code in src/aeiva/plugin/plug.py
class PluginManager:
    """
    Manages multiple PluginSources and controls plugin imports.
    """

    def __init__(self) -> None:
        self.plugin_sources: Dict[str, PluginSource] = {}

    def create_plugin_source(self, name: str, search_path: Optional[List[str]] = None) -> PluginSource:
        """
        Creates a new PluginSource.

        :param name: Unique name for the plugin source.
        :param search_path: List of paths to search for plugins.
        :return: The created PluginSource.
        """
        if name in self.plugin_sources:
            raise ValueError(f"Plugin source '{name}' already exists.")
        source = PluginSource(name, search_path)
        self.plugin_sources[name] = source
        print(f"PluginManager: Created plugin source '{name}' with search paths {search_path}.")
        return source

    def get_plugin_source(self, name: str) -> Optional[PluginSource]:
        """
        Retrieves a PluginSource by name.

        :param name: Name of the PluginSource.
        :return: The PluginSource instance, or None if not found.
        """
        return self.plugin_sources.get(name)

    def remove_plugin_source(self, name: str) -> None:
        """
        Removes a PluginSource.

        :param name: Name of the PluginSource to remove.
        """
        source = self.plugin_sources.pop(name, None)
        if source:
            source.disable()
            for plugin_name in list(source._modules.keys()):
                source.unload_plugin(plugin_name)
            print(f"PluginManager: Removed plugin source '{name}'.")
        else:
            print(f"PluginManager: Plugin source '{name}' does not exist.")
create_plugin_source(name, search_path=None)

Creates a new PluginSource.

:param name: Unique name for the plugin source.
:param search_path: List of paths to search for plugins.
:return: The created PluginSource.

Source code in src/aeiva/plugin/plug.py
def create_plugin_source(self, name: str, search_path: Optional[List[str]] = None) -> PluginSource:
    """
    Creates a new PluginSource.

    :param name: Unique name for the plugin source.
    :param search_path: List of paths to search for plugins.
    :return: The created PluginSource.
    """
    if name in self.plugin_sources:
        raise ValueError(f"Plugin source '{name}' already exists.")
    source = PluginSource(name, search_path)
    self.plugin_sources[name] = source
    print(f"PluginManager: Created plugin source '{name}' with search paths {search_path}.")
    return source
get_plugin_source(name)

Retrieves a PluginSource by name.

:param name: Name of the PluginSource.
:return: The PluginSource instance, or None if not found.

Source code in src/aeiva/plugin/plug.py
def get_plugin_source(self, name: str) -> Optional[PluginSource]:
    """
    Retrieves a PluginSource by name.

    :param name: Name of the PluginSource.
    :return: The PluginSource instance, or None if not found.
    """
    return self.plugin_sources.get(name)
remove_plugin_source(name)

Removes a PluginSource.

:param name: Name of the PluginSource to remove.

Source code in src/aeiva/plugin/plug.py
def remove_plugin_source(self, name: str) -> None:
    """
    Removes a PluginSource.

    :param name: Name of the PluginSource to remove.
    """
    source = self.plugin_sources.pop(name, None)
    if source:
        source.disable()
        for plugin_name in list(source._modules.keys()):
            source.unload_plugin(plugin_name)
        print(f"PluginManager: Removed plugin source '{name}'.")
    else:
        print(f"PluginManager: Plugin source '{name}' does not exist.")

PluginSource

Represents an isolated source of plugins. Each plugin is a directory containing a plugin.py file.

Source code in src/aeiva/plugin/plug.py
class PluginSource:
    """
    Represents an isolated source of plugins.
    Each plugin is a directory containing a `plugin.py` file.
    """

    def __init__(self, name: str, search_path: Optional[List[str]] = None) -> None:
        """
        Initializes the PluginSource.

        :param name: Unique name for the plugin source.
        :param search_path: List of paths (directories or zip files) to search for plugins.
        """
        self.name = name
        self.search_path = search_path or []
        self._lock = threading.Lock()
        self._modules: Dict[str, ModuleType] = {}
        self.namespace = f"_plug_{self.name}"
        self._finder = PluginFinder(self)
        self._finder_enabled = False

    def __enter__(self) -> 'PluginSource':
        """Enter the runtime context related to this object."""
        self.enable()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the runtime context."""
        self.disable()

    def enable(self) -> None:
        """Enable the plugin import mechanism."""
        if not self._finder_enabled:
            sys.meta_path.insert(0, self._finder)
            self._finder_enabled = True
            print(f"PluginSource: Import hook enabled for namespace '{self.namespace}'.")

    def disable(self) -> None:
        """Disable the plugin import mechanism."""
        if self._finder_enabled:
            try:
                sys.meta_path.remove(self._finder)
                print(f"PluginSource: Import hook disabled for namespace '{self.namespace}'.")
            except ValueError:
                print(f"PluginSource: Import hook for namespace '{self.namespace}' was not found in sys.meta_path.")
            self._finder_enabled = False

    def list_plugins(self) -> List[str]:
        """
        Lists available plugins in the search paths.
        Each plugin is a directory containing a `plugin.py` file.

        :return: List of plugin names.
        """
        plugins = set()
        for path in self.search_path:
            if zipfile.is_zipfile(path):
                with zipfile.ZipFile(path, 'r') as z:
                    # Identify top-level directories containing `plugin.py`
                    plugin_dirs = set()
                    for file in z.namelist():
                        parts = file.split('/')
                        if len(parts) >= 2 and parts[-1] == 'plugin.py':
                            plugin_dir = parts[0]
                            plugin_dirs.add(plugin_dir)
                    plugins.update(plugin_dirs)
            else:
                # Assume it's a directory
                if not os.path.isdir(path):
                    print(f"PluginSource: Path '{path}' is not a directory or a zip file. Skipping.")
                    continue
                for entry in os.listdir(path):
                    plugin_path = os.path.join(path, entry)
                    if os.path.isdir(plugin_path):
                        plugin_main = os.path.join(plugin_path, 'plugin.py')
                        if os.path.isfile(plugin_main):
                            plugins.add(entry)
        return list(plugins)

    def get_plugin_code(self, plugin_name: str) -> str:
        """
        Get the source code of the plugin's `plugin.py`.

        :param plugin_name: Name of the plugin to load.
        :return: Source code of `plugin.py` as a string.
        """
        for path in self.search_path:
            if zipfile.is_zipfile(path):
                with zipfile.ZipFile(path, 'r') as z:
                    plugin_main = f"{plugin_name}/plugin.py"
                    if plugin_main in z.namelist():
                        print(f"PluginSource: Found plugin '{plugin_name}' in zip file '{path}'.")
                        return z.read(plugin_main).decode('utf-8')
            else:
                # Assume it's a directory
                plugin_dir = os.path.join(path, plugin_name)
                plugin_main = os.path.join(plugin_dir, 'plugin.py')
                if os.path.isfile(plugin_main):
                    print(f"PluginSource: Found plugin '{plugin_name}' as module file '{plugin_main}'.")
                    with open(plugin_main, 'r', encoding='utf-8') as f:
                        return f.read()
        raise ImportError(f"Cannot find plugin '{plugin_name}'.")

    def load_plugin(self, plugin_name: str) -> ModuleType:
        """
        Loads a plugin by name.

        :param plugin_name: Name of the plugin to load.
        :return: The loaded plugin module.
        """
        with self._lock:
            full_name = f"{self.namespace}.{plugin_name}"
            if full_name in sys.modules:
                print(f"PluginSource: Plugin '{plugin_name}' is already loaded as '{full_name}'.")
                return sys.modules[full_name]
            # Enable the finder if not already enabled
            self.enable()
            try:
                print(f"PluginSource: Loading plugin '{plugin_name}' as '{full_name}'.")
                module = importlib.import_module(full_name)
                self._modules[plugin_name] = module
                return module
            except ImportError as e:
                print(f"PluginSource: Cannot import plugin '{plugin_name}': {e}")
                raise

    def unload_plugin(self, plugin_name: str) -> None:
        """
        Unloads a plugin by name.

        :param plugin_name: Name of the plugin to unload.
        """
        with self._lock:
            full_name = f"{self.namespace}.{plugin_name}"
            module = self._modules.pop(plugin_name, None)
            if module:
                if hasattr(module, 'deactivate'):
                    try:
                        print(f"PluginSource: Deactivating plugin '{plugin_name}'.")
                        getattr(module, 'deactivate')()
                    except Exception as e:
                        print(f"PluginSource: Error during deactivation of plugin '{plugin_name}': {e}")
                if full_name in sys.modules:
                    del sys.modules[full_name]
                    print(f"PluginSource: Plugin '{plugin_name}' unloaded and removed from sys.modules.")
            else:
                print(f"PluginSource: Plugin '{plugin_name}' is not loaded.")

    def load_resource(self, plugin_name: str, resource_name: str) -> bytes:
        """
        Loads a resource from a plugin.

        :param plugin_name: Name of the plugin.
        :param resource_name: Name of the resource file.
        :return: Contents of the resource file as bytes.
        """
        for path in self.search_path:
            if zipfile.is_zipfile(path):
                with zipfile.ZipFile(path, 'r') as z:
                    resource_file = f"{plugin_name}/{resource_name}"
                    if resource_file in z.namelist():
                        print(f"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' in zip '{path}'.")
                        return z.read(resource_file)
            else:
                # Assume it's a directory
                resource_path = os.path.join(path, plugin_name, resource_name)
                if os.path.isfile(resource_path):
                    print(f"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' at '{resource_path}'.")
                    with open(resource_path, 'rb') as f:
                        return f.read()
        raise FileNotFoundError(f"Resource '{resource_name}' not found in plugin '{plugin_name}'.")
__enter__()

Enter the runtime context related to this object.

Source code in src/aeiva/plugin/plug.py
def __enter__(self) -> 'PluginSource':
    """Enter the runtime context related to this object."""
    self.enable()
    return self
__exit__(exc_type, exc_value, traceback)

Exit the runtime context.

Source code in src/aeiva/plugin/plug.py
def __exit__(self, exc_type, exc_value, traceback) -> None:
    """Exit the runtime context."""
    self.disable()
__init__(name, search_path=None)

Initializes the PluginSource.

:param name: Unique name for the plugin source.
:param search_path: List of paths (directories or zip files) to search for plugins.

Source code in src/aeiva/plugin/plug.py
def __init__(self, name: str, search_path: Optional[List[str]] = None) -> None:
    """
    Initializes the PluginSource.

    :param name: Unique name for the plugin source.
    :param search_path: List of paths (directories or zip files) to search for plugins.
    """
    self.name = name
    self.search_path = search_path or []
    self._lock = threading.Lock()
    self._modules: Dict[str, ModuleType] = {}
    self.namespace = f"_plug_{self.name}"
    self._finder = PluginFinder(self)
    self._finder_enabled = False
disable()

Disable the plugin import mechanism.

Source code in src/aeiva/plugin/plug.py
def disable(self) -> None:
    """Disable the plugin import mechanism."""
    if self._finder_enabled:
        try:
            sys.meta_path.remove(self._finder)
            print(f"PluginSource: Import hook disabled for namespace '{self.namespace}'.")
        except ValueError:
            print(f"PluginSource: Import hook for namespace '{self.namespace}' was not found in sys.meta_path.")
        self._finder_enabled = False
enable()

Enable the plugin import mechanism.

Source code in src/aeiva/plugin/plug.py
def enable(self) -> None:
    """Enable the plugin import mechanism."""
    if not self._finder_enabled:
        sys.meta_path.insert(0, self._finder)
        self._finder_enabled = True
        print(f"PluginSource: Import hook enabled for namespace '{self.namespace}'.")
get_plugin_code(plugin_name)

Get the source code of the plugin's plugin.py.

:param plugin_name: Name of the plugin to load.
:return: Source code of plugin.py as a string.

Source code in src/aeiva/plugin/plug.py
def get_plugin_code(self, plugin_name: str) -> str:
    """
    Get the source code of the plugin's `plugin.py`.

    :param plugin_name: Name of the plugin to load.
    :return: Source code of `plugin.py` as a string.
    """
    for path in self.search_path:
        if zipfile.is_zipfile(path):
            with zipfile.ZipFile(path, 'r') as z:
                plugin_main = f"{plugin_name}/plugin.py"
                if plugin_main in z.namelist():
                    print(f"PluginSource: Found plugin '{plugin_name}' in zip file '{path}'.")
                    return z.read(plugin_main).decode('utf-8')
        else:
            # Assume it's a directory
            plugin_dir = os.path.join(path, plugin_name)
            plugin_main = os.path.join(plugin_dir, 'plugin.py')
            if os.path.isfile(plugin_main):
                print(f"PluginSource: Found plugin '{plugin_name}' as module file '{plugin_main}'.")
                with open(plugin_main, 'r', encoding='utf-8') as f:
                    return f.read()
    raise ImportError(f"Cannot find plugin '{plugin_name}'.")
list_plugins()

Lists available plugins in the search paths. Each plugin is a directory containing a plugin.py file.

:return: List of plugin names.

Source code in src/aeiva/plugin/plug.py
def list_plugins(self) -> List[str]:
    """
    Lists available plugins in the search paths.
    Each plugin is a directory containing a `plugin.py` file.

    :return: List of plugin names.
    """
    plugins = set()
    for path in self.search_path:
        if zipfile.is_zipfile(path):
            with zipfile.ZipFile(path, 'r') as z:
                # Identify top-level directories containing `plugin.py`
                plugin_dirs = set()
                for file in z.namelist():
                    parts = file.split('/')
                    if len(parts) >= 2 and parts[-1] == 'plugin.py':
                        plugin_dir = parts[0]
                        plugin_dirs.add(plugin_dir)
                plugins.update(plugin_dirs)
        else:
            # Assume it's a directory
            if not os.path.isdir(path):
                print(f"PluginSource: Path '{path}' is not a directory or a zip file. Skipping.")
                continue
            for entry in os.listdir(path):
                plugin_path = os.path.join(path, entry)
                if os.path.isdir(plugin_path):
                    plugin_main = os.path.join(plugin_path, 'plugin.py')
                    if os.path.isfile(plugin_main):
                        plugins.add(entry)
    return list(plugins)
load_plugin(plugin_name)

Loads a plugin by name.

:param plugin_name: Name of the plugin to load.
:return: The loaded plugin module.

Source code in src/aeiva/plugin/plug.py
def load_plugin(self, plugin_name: str) -> ModuleType:
    """
    Loads a plugin by name.

    :param plugin_name: Name of the plugin to load.
    :return: The loaded plugin module.
    """
    with self._lock:
        full_name = f"{self.namespace}.{plugin_name}"
        if full_name in sys.modules:
            print(f"PluginSource: Plugin '{plugin_name}' is already loaded as '{full_name}'.")
            return sys.modules[full_name]
        # Enable the finder if not already enabled
        self.enable()
        try:
            print(f"PluginSource: Loading plugin '{plugin_name}' as '{full_name}'.")
            module = importlib.import_module(full_name)
            self._modules[plugin_name] = module
            return module
        except ImportError as e:
            print(f"PluginSource: Cannot import plugin '{plugin_name}': {e}")
            raise
load_resource(plugin_name, resource_name)

Loads a resource from a plugin.

:param plugin_name: Name of the plugin.
:param resource_name: Name of the resource file.
:return: Contents of the resource file as bytes.

Source code in src/aeiva/plugin/plug.py
def load_resource(self, plugin_name: str, resource_name: str) -> bytes:
    """
    Loads a resource from a plugin.

    :param plugin_name: Name of the plugin.
    :param resource_name: Name of the resource file.
    :return: Contents of the resource file as bytes.
    """
    for path in self.search_path:
        if zipfile.is_zipfile(path):
            with zipfile.ZipFile(path, 'r') as z:
                resource_file = f"{plugin_name}/{resource_name}"
                if resource_file in z.namelist():
                    print(f"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' in zip '{path}'.")
                    return z.read(resource_file)
        else:
            # Assume it's a directory
            resource_path = os.path.join(path, plugin_name, resource_name)
            if os.path.isfile(resource_path):
                print(f"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' at '{resource_path}'.")
                with open(resource_path, 'rb') as f:
                    return f.read()
    raise FileNotFoundError(f"Resource '{resource_name}' not found in plugin '{plugin_name}'.")
unload_plugin(plugin_name)

Unloads a plugin by name.

:param plugin_name: Name of the plugin to unload.

Source code in src/aeiva/plugin/plug.py
def unload_plugin(self, plugin_name: str) -> None:
    """
    Unloads a plugin by name.

    :param plugin_name: Name of the plugin to unload.
    """
    with self._lock:
        full_name = f"{self.namespace}.{plugin_name}"
        module = self._modules.pop(plugin_name, None)
        if module:
            if hasattr(module, 'deactivate'):
                try:
                    print(f"PluginSource: Deactivating plugin '{plugin_name}'.")
                    getattr(module, 'deactivate')()
                except Exception as e:
                    print(f"PluginSource: Error during deactivation of plugin '{plugin_name}': {e}")
            if full_name in sys.modules:
                del sys.modules[full_name]
                print(f"PluginSource: Plugin '{plugin_name}' unloaded and removed from sys.modules.")
        else:
            print(f"PluginSource: Plugin '{plugin_name}' is not loaded.")

test

Main Application

This script demonstrates the usage of the plug module and plugin system.

society

society

Society

Bases: ABC

Abstract base class representing a Society that connects an environment and agents.

The Society enables agents to interact with each other and with the environment, providing mechanisms for integrating social systems, such as communication or economy.

Attributes:

Name Type Description
config Any

Configuration settings for the society.

environment Environment

The environment in which agents operate.

agents Dict[str, Any]

A dictionary of agents within the society.

social_systems Dict[str, Any]

A dictionary representing various social systems (e.g., communication).

Source code in src/aeiva/society/society.py
class Society(ABC):
    """
    Abstract base class representing a Society that connects an environment and agents.

    The Society enables agents to interact with each other and with the environment, providing
    mechanisms for integrating social systems, such as communication or economy.

    Attributes:
        config (Any): Configuration settings for the society.
        environment (Environment): The environment in which agents operate.
        agents (Dict[str, Any]): A dictionary of agents within the society.
        social_systems (Dict[str, Any]): A dictionary representing various social systems (e.g., communication).
    """

    def __init__(self, config: Any, environment: Any, agents: Dict[str, Any]):
        """
        Initialize the Society with the provided configuration, environment, and agents.

        Args:
            config (Any): Configuration settings for the society.
            environment (Environment): The environment in which agents operate.
            agents (Dict[str, Any]): A dictionary of agents within the society, keyed by their IDs.
        """
        self.config = config
        self.environment = environment
        self.agents = agents  # Agents are stored in a dictionary with IDs as keys
        self.social_systems = self.init_social_systems()

    @abstractmethod
    def init_social_systems(self) -> Dict[str, Any]:
        """
        Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).

        Returns:
            Dict[str, Any]: A dictionary of initialized social systems.
        """
        pass

    @abstractmethod
    async def setup(self) -> None:
        """
        Asynchronously set up the society's components, such as initializing the environment and agents.
        """
        await self.environment.setup()
        await asyncio.gather(*(agent.setup() for agent in self.agents.values()))
        print("Society: Setup completed.")

    @abstractmethod
    async def run(self) -> None:
        """
        Asynchronously run the society, managing interactions between agents and the environment.

        This method should control the flow of interactions between agents and the environment,
        and it can be designed as a continuous loop or a task-based execution.
        """
        pass

    def add_agent(self, agent_id: str, agent: Any) -> None:
        """
        Add a new agent to the society.

        Args:
            agent_id (str): The unique identifier of the agent.
            agent (Any): The agent object to add to the society.
        """
        self.agents[agent_id] = agent

    def remove_agent(self, agent_id: str) -> None:
        """
        Remove an agent from the society by its ID.

        Args:
            agent_id (str): The unique identifier of the agent.
        """
        if agent_id in self.agents:
            del self.agents[agent_id]

    def get_agent(self, agent_id: str) -> Any:
        """
        Retrieve an agent by its ID.

        Args:
            agent_id (str): The unique identifier of the agent.

        Returns:
            Any: The agent object, if found.
        """
        return self.agents.get(agent_id, None)

    def handle_error(self, error: Exception) -> None:
        """
        Handle errors that occur during society operations.

        Args:
            error (Exception): The exception that was raised.
        """
        print(f"Society encountered an error: {error}")
__init__(config, environment, agents)

Initialize the Society with the provided configuration, environment, and agents.

Parameters:

Name Type Description Default
config Any

Configuration settings for the society.

required
environment Environment

The environment in which agents operate.

required
agents Dict[str, Any]

A dictionary of agents within the society, keyed by their IDs.

required
Source code in src/aeiva/society/society.py
def __init__(self, config: Any, environment: Any, agents: Dict[str, Any]):
    """
    Initialize the Society with the provided configuration, environment, and agents.

    Args:
        config (Any): Configuration settings for the society.
        environment (Environment): The environment in which agents operate.
        agents (Dict[str, Any]): A dictionary of agents within the society, keyed by their IDs.
    """
    self.config = config
    self.environment = environment
    self.agents = agents  # Agents are stored in a dictionary with IDs as keys
    self.social_systems = self.init_social_systems()
add_agent(agent_id, agent)

Add a new agent to the society.

Parameters:

Name Type Description Default
agent_id str

The unique identifier of the agent.

required
agent Any

The agent object to add to the society.

required
Source code in src/aeiva/society/society.py
def add_agent(self, agent_id: str, agent: Any) -> None:
    """
    Add a new agent to the society.

    Args:
        agent_id (str): The unique identifier of the agent.
        agent (Any): The agent object to add to the society.
    """
    self.agents[agent_id] = agent
get_agent(agent_id)

Retrieve an agent by its ID.

Parameters:

Name Type Description Default
agent_id str

The unique identifier of the agent.

required

Returns:

Name Type Description
Any Any

The agent object, if found.

Source code in src/aeiva/society/society.py
def get_agent(self, agent_id: str) -> Any:
    """
    Retrieve an agent by its ID.

    Args:
        agent_id (str): The unique identifier of the agent.

    Returns:
        Any: The agent object, if found.
    """
    return self.agents.get(agent_id, None)
handle_error(error)

Handle errors that occur during society operations.

Parameters:

Name Type Description Default
error Exception

The exception that was raised.

required
Source code in src/aeiva/society/society.py
def handle_error(self, error: Exception) -> None:
    """
    Handle errors that occur during society operations.

    Args:
        error (Exception): The exception that was raised.
    """
    print(f"Society encountered an error: {error}")
init_social_systems() abstractmethod

Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).

Returns:

Type Description
Dict[str, Any]

Dict[str, Any]: A dictionary of initialized social systems.

Source code in src/aeiva/society/society.py
@abstractmethod
def init_social_systems(self) -> Dict[str, Any]:
    """
    Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).

    Returns:
        Dict[str, Any]: A dictionary of initialized social systems.
    """
    pass
remove_agent(agent_id)

Remove an agent from the society by its ID.

Parameters:

Name Type Description Default
agent_id str

The unique identifier of the agent.

required
Source code in src/aeiva/society/society.py
def remove_agent(self, agent_id: str) -> None:
    """
    Remove an agent from the society by its ID.

    Args:
        agent_id (str): The unique identifier of the agent.
    """
    if agent_id in self.agents:
        del self.agents[agent_id]
run() abstractmethod async

Asynchronously run the society, managing interactions between agents and the environment.

This method should control the flow of interactions between agents and the environment, and it can be designed as a continuous loop or a task-based execution.

Source code in src/aeiva/society/society.py
@abstractmethod
async def run(self) -> None:
    """
    Asynchronously run the society, managing interactions between agents and the environment.

    This method should control the flow of interactions between agents and the environment,
    and it can be designed as a continuous loop or a task-based execution.
    """
    pass
setup() abstractmethod async

Asynchronously set up the society's components, such as initializing the environment and agents.

Source code in src/aeiva/society/society.py
@abstractmethod
async def setup(self) -> None:
    """
    Asynchronously set up the society's components, such as initializing the environment and agents.
    """
    await self.environment.setup()
    await asyncio.gather(*(agent.setup() for agent in self.agents.values()))
    print("Society: Setup completed.")

storage

azure_ai_search_config

AzureAISearchConfig dataclass

Bases: BaseConfig

Configuration for Azure Cognitive Search vector database.

Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_config.py
@dataclass
class AzureAISearchConfig(BaseConfig):
    """
    Configuration for Azure Cognitive Search vector database.
    """

    collection_name: str = field(
        default="mem0",
        metadata={"help": "Name of the collection (index name)."}
    )
    service_name: Optional[str] = field(
        default=None,
        metadata={"help": "Azure Cognitive Search service name."}
    )
    api_key: Optional[str] = field(
        default=None,
        metadata={"help": "API key for the Azure Cognitive Search service."}
    )
    embedding_model_dims: int = field(
        default=1536,
        metadata={"help": "Dimension of the embedding vector."}
    )
    use_compression: bool = field(
        default=False,
        metadata={"help": "Whether to use scalar quantization vector compression."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Validate that service_name and api_key are provided
        if not self.service_name or not self.api_key:
            raise ValueError("Both 'service_name' and 'api_key' must be provided.")

azure_ai_search_database

AzureAISearchDatabase

Bases: VectorDatabase

Concrete implementation of VectorDatabase using Azure Cognitive Search.

Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
class AzureAISearchDatabase(VectorDatabase):
    """
    Concrete implementation of VectorDatabase using Azure Cognitive Search.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the Azure Cognitive Search vector store.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.index_name = config.get('collection_name')
        self.service_name = config.get('service_name')
        self.api_key = config.get('api_key')
        self.embedding_model_dims = config.get('embedding_model_dims')
        self.use_compression = config.get('use_compression', False)

        if not all([self.service_name, self.api_key, self.index_name, self.embedding_model_dims]):
            raise ValueError("Required configuration parameters are missing.")

        self.create_client(
            uri=None,
            service_name=self.service_name,
            api_key=self.api_key
        )
        self.create_collection(
            collection_name=self.index_name,
            vector_size=self.embedding_model_dims,
            distance_metric='cosine'
        )

    def create_client(
        self,
        uri: Optional[str] = None,
        service_name: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
    ) -> None:
        """
        Initializes the client connection to the vector store.

        Args:
            uri (Optional[str]): Not used for Azure Cognitive Search.
            service_name (str): Azure Cognitive Search service name.
            api_key (str): API key for the Azure Cognitive Search service.
            **kwargs: Additional parameters.
        """
        if not service_name or not api_key:
            raise ValueError("Both 'service_name' and 'api_key' must be provided.")

        endpoint = f"https://{service_name}.search.windows.net"
        credential = AzureKeyCredential(api_key)
        self.search_client = SearchClient(
            endpoint=endpoint,
            index_name=self.index_name,
            credential=credential
        )
        self.index_client = SearchIndexClient(
            endpoint=endpoint,
            credential=credential
        )

    def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
        """
        Create a new vector collection (index) in Azure Cognitive Search.

        Args:
            collection_name (str): The name of the collection.
            vector_size (int): The dimensionality of the vectors.
            distance_metric (str): The distance metric to use (e.g., 'cosine').
        """
        # Check if the index already exists
        try:
            self.index_client.get_index(collection_name)
            logger.info(f"Index {collection_name} already exists. Skipping creation.")
            return
        except ResourceNotFoundError:
            pass  # Index does not exist, proceed to create

        if self.use_compression:
            vector_type = "Collection(Edm.Half)"
            compression_name = "myCompression"
            compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
        else:
            vector_type = "Collection(Edm.Single)"
            compression_name = None
            compression_configurations = []

        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SearchField(
                name="vector",
                type=vector_type,
                searchable=True,
                vector_search_dimensions=vector_size,
                vector_search_profile_name="my-vector-config",
            ),
            SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
        ]

        vector_search = VectorSearch(
            profiles=[
                VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")
            ],
            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
            compressions=compression_configurations,
        )
        index = SearchIndex(name=collection_name, fields=fields, vector_search=vector_search)
        self.index_client.create_or_update_index(index)
        logger.info(f"Index {collection_name} created successfully.")

    def insert_vectors(
        self,
        collection_name: str,
        vectors: List[List[float]],
        payloads: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into the index.

        Args:
            collection_name (str): The name of the collection.
            vectors (List[List[float]]): A list of vectors to insert.
            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
            ids (Optional[List[str]]): Optional unique identifiers for each vector.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")

        if ids is None:
            ids = [str(i) for i in range(len(vectors))]
        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

        documents = [
            {"id": id_, "vector": vector, "payload": json.dumps(payload)}
            for id_, vector, payload in zip(ids, vectors, payloads)
        ]
        self.search_client.upload_documents(documents)
        logger.info(f"Inserted {len(vectors)} vectors into index {collection_name}.")

    def search_vectors(
        self,
        collection_name: str,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for similar vectors in a collection.

        Args:
            collection_name (str): The name of the collection.
            query_vector (List[float]): The vector to search with.
            top_k (int): The number of top results to return.
            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

        Returns:
            List[Dict[str, Any]]: A list of search results.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")

        vector_query = VectorizedQuery(vector=query_vector, k_nearest_neighbors=top_k, fields="vector")
        search_results = self.search_client.search(vector_queries=[vector_query], top=top_k)

        results = []
        for result in search_results:
            payload = json.loads(result["payload"])
            # Keep only results whose payload matches every filter key/value pair.
            if filters and any(
                key not in payload or payload[key] != value
                for key, value in filters.items()
            ):
                continue
            result_dict = {
                "id": result["id"],
                "score": result["@search.score"],
                "payload": payload
            }
            results.append(result_dict)
        return results

    def delete_vector(self, collection_name: str, vector_id: str) -> None:
        """
        Delete a vector from a collection by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to delete.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")
        self.search_client.delete_documents(documents=[{"id": vector_id}])
        logger.info(f"Deleted vector with ID {vector_id} from collection {collection_name}.")

    def update_vector(
        self,
        collection_name: str,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Update a vector's data or payload.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to update.
            vector (Optional[List[float]]): The new vector data.
            payload (Optional[Dict[str, Any]]): The new payload data.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")

        document = {"id": vector_id}
        if vector is not None:
            document["vector"] = vector
        if payload is not None:
            document["payload"] = json.dumps(payload)
        self.search_client.merge_or_upload_documents(documents=[document])
        logger.info(f"Updated vector with ID {vector_id} in collection {collection_name}.")

    def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
        """
        Retrieve a vector by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector.

        Returns:
            Dict[str, Any]: A dictionary containing the vector data and payload.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")
        try:
            result = self.search_client.get_document(key=vector_id)
            payload = json.loads(result["payload"])
            vector_data = {
                "id": result["id"],
                "vector": result["vector"],
                "payload": payload
            }
            return vector_data
        except ResourceNotFoundError:
            raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

    def list_collections(self) -> List[str]:
        """
        List all available vector collections.

        Returns:
            List[str]: A list of collection names.
        """
        indexes = self.index_client.list_indexes()
        return [index.name for index in indexes]

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire vector collection.

        Args:
            collection_name (str): The name of the collection to delete.
        """
        self.index_client.delete_index(collection_name)
        logger.info(f"Deleted collection {collection_name}.")

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """
        Get information about a collection.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            Dict[str, Any]: Information about the collection.
        """
        index = self.index_client.get_index(collection_name)
        return {
            "name": index.name,
            "fields": [field.name for field in index.fields],
            "vector_search": index.vector_search
        }

    def __del__(self):
        """Clean up resources."""
        self.search_client.close()
        self.index_client.close()
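
A minimal usage sketch (credentials, ids, and vectors are placeholders; constructing the database connects to Azure and creates the index, so a reachable service is assumed):

db = AzureAISearchDatabase({
    "collection_name": "documents",
    "service_name": "my-search-service",   # placeholder
    "api_key": "<azure-search-api-key>",   # placeholder
    "embedding_model_dims": 4,             # tiny dimension for illustration
})
db.insert_vectors(
    collection_name="documents",
    vectors=[[0.1, 0.2, 0.3, 0.4]],
    payloads=[{"source": "doc-1"}],
    ids=["1"],
)
hits = db.search_vectors("documents", query_vector=[0.1, 0.2, 0.3, 0.4], top_k=1)
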
__del__()

Clean up resources.

Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def __del__(self):
    """Clean up resources."""
    self.search_client.close()
    self.index_client.close()
__init__(config)

Initialize the Azure Cognitive Search vector store.

Parameters:

config (Dict[str, Any]): Configuration dictionary. Required.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the Azure Cognitive Search vector store.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.index_name = config.get('collection_name')
    self.service_name = config.get('service_name')
    self.api_key = config.get('api_key')
    self.embedding_model_dims = config.get('embedding_model_dims')
    self.use_compression = config.get('use_compression', False)

    if not all([self.service_name, self.api_key, self.index_name, self.embedding_model_dims]):
        raise ValueError("Required configuration parameters are missing.")

    self.create_client(
        uri=None,
        service_name=self.service_name,
        api_key=self.api_key
    )
    self.create_collection(
        collection_name=self.index_name,
        vector_size=self.embedding_model_dims,
        distance_metric='cosine'
    )
create_client(uri=None, service_name=None, api_key=None, **kwargs)

Initializes the client connection to the vector store.

Parameters:

uri (Optional[str]): Not used for Azure Cognitive Search. Default: None.
service_name (str): Azure Cognitive Search service name. Default: None.
api_key (str): API key for the Azure Cognitive Search service. Default: None.
**kwargs: Additional parameters. Default: {}.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def create_client(
    self,
    uri: Optional[str] = None,
    service_name: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs
) -> None:
    """
    Initializes the client connection to the vector store.

    Args:
        uri (Optional[str]): Not used for Azure Cognitive Search.
        service_name (str): Azure Cognitive Search service name.
        api_key (str): API key for the Azure Cognitive Search service.
        **kwargs: Additional parameters.
    """
    if not service_name or not api_key:
        raise ValueError("Both 'service_name' and 'api_key' must be provided.")

    endpoint = f"https://{service_name}.search.windows.net"
    credential = AzureKeyCredential(api_key)
    self.search_client = SearchClient(
        endpoint=endpoint,
        index_name=self.index_name,
        credential=credential
    )
    self.index_client = SearchIndexClient(
        endpoint=endpoint,
        credential=credential
    )
create_collection(collection_name, vector_size, distance_metric)

Create a new vector collection (index) in Azure Cognitive Search.

Parameters:

collection_name (str): The name of the collection. Required.
vector_size (int): The dimensionality of the vectors. Required.
distance_metric (str): The distance metric to use (e.g., 'cosine'). Required.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
    """
    Create a new vector collection (index) in Azure Cognitive Search.

    Args:
        collection_name (str): The name of the collection.
        vector_size (int): The dimensionality of the vectors.
        distance_metric (str): The distance metric to use (e.g., 'cosine').
    """
    # Check if the index already exists
    try:
        self.index_client.get_index(collection_name)
        logger.info(f"Index {collection_name} already exists. Skipping creation.")
        return
    except ResourceNotFoundError:
        pass  # Index does not exist, proceed to create

    if self.use_compression:
        vector_type = "Collection(Edm.Half)"
        compression_name = "myCompression"
        compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
    else:
        vector_type = "Collection(Edm.Single)"
        compression_name = None
        compression_configurations = []

    fields = [
        SimpleField(name="id", type=SearchFieldDataType.String, key=True),
        SearchField(
            name="vector",
            type=vector_type,
            searchable=True,
            vector_search_dimensions=vector_size,
            vector_search_profile_name="my-vector-config",
        ),
        SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
    ]

    vector_search = VectorSearch(
        profiles=[
            VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")
        ],
        algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
        compressions=compression_configurations,
    )
    index = SearchIndex(name=collection_name, fields=fields, vector_search=vector_search)
    self.index_client.create_or_update_index(index)
    logger.info(f"Index {collection_name} created successfully.")
delete_collection(collection_name)

Delete an entire vector collection.

Parameters:

collection_name (str): The name of the collection to delete. Required.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def delete_collection(self, collection_name: str) -> None:
    """
    Delete an entire vector collection.

    Args:
        collection_name (str): The name of the collection to delete.
    """
    self.index_client.delete_index(collection_name)
    logger.info(f"Deleted collection {collection_name}.")
delete_vector(collection_name, vector_id)

Delete a vector from a collection by its ID.

Parameters:

collection_name (str): The name of the collection. Required.
vector_id (str): The unique identifier of the vector to delete. Required.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:
    """
    Delete a vector from a collection by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to delete.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")
    self.search_client.delete_documents(documents=[{"id": vector_id}])
    logger.info(f"Deleted vector with ID {vector_id} from collection {collection_name}.")
get_collection_info(collection_name)

Get information about a collection.

Parameters:

collection_name (str): The name of the collection. Required.

Returns:

Dict[str, Any]: Information about the collection.

Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """
    Get information about a collection.

    Args:
        collection_name (str): The name of the collection.

    Returns:
        Dict[str, Any]: Information about the collection.
    """
    index = self.index_client.get_index(collection_name)
    return {
        "name": index.name,
        "fields": [field.name for field in index.fields],
        "vector_search": index.vector_search
    }
get_vector(collection_name, vector_id)

Retrieve a vector by its ID.

Parameters:

collection_name (str): The name of the collection. Required.
vector_id (str): The unique identifier of the vector. Required.

Returns:

Dict[str, Any]: A dictionary containing the vector data and payload.

Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
    """
    Retrieve a vector by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector.

    Returns:
        Dict[str, Any]: A dictionary containing the vector data and payload.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")
    try:
        result = self.search_client.get_document(key=vector_id)
        payload = json.loads(result["payload"])
        vector_data = {
            "id": result["id"],
            "vector": result["vector"],
            "payload": payload
        }
        return vector_data
    except ResourceNotFoundError:
        raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")
insert_vectors(collection_name, vectors, payloads=None, ids=None)

Insert vectors into the index.

Parameters:

collection_name (str): The name of the collection. Required.
vectors (List[List[float]]): A list of vectors to insert. Required.
payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector. Default: None.
ids (Optional[List[str]]): Optional unique identifiers for each vector. Default: None.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def insert_vectors(
    self,
    collection_name: str,
    vectors: List[List[float]],
    payloads: Optional[List[Dict[str, Any]]] = None,
    ids: Optional[List[str]] = None
) -> None:
    """
    Insert vectors into the index.

    Args:
        collection_name (str): The name of the collection.
        vectors (List[List[float]]): A list of vectors to insert.
        payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
        ids (Optional[List[str]]): Optional unique identifiers for each vector.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")

    if ids is None:
        ids = [str(i) for i in range(len(vectors))]
    if payloads is None:
        payloads = [{} for _ in range(len(vectors))]
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

    documents = [
        {"id": id_, "vector": vector, "payload": json.dumps(payload)}
        for id_, vector, payload in zip(ids, vectors, payloads)
    ]
    self.search_client.upload_documents(documents)
    logger.info(f"Inserted {len(vectors)} vectors into index {collection_name}.")
list_collections()

List all available vector collections.

Returns:

List[str]: A list of collection names.

Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def list_collections(self) -> List[str]:
    """
    List all available vector collections.

    Returns:
        List[str]: A list of collection names.
    """
    indexes = self.index_client.list_indexes()
    return [index.name for index in indexes]
search_vectors(collection_name, query_vector, top_k=5, filters=None)

Search for similar vectors in a collection.

Parameters:

collection_name (str): The name of the collection. Required.
query_vector (List[float]): The vector to search with. Required.
top_k (int): The number of top results to return. Default: 5.
filters (Optional[Dict[str, Any]]): Optional filters to apply to the search. Default: None.

Returns:

List[Dict[str, Any]]: A list of search results.

Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def search_vectors(
    self,
    collection_name: str,
    query_vector: List[float],
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    Search for similar vectors in a collection.

    Args:
        collection_name (str): The name of the collection.
        query_vector (List[float]): The vector to search with.
        top_k (int): The number of top results to return.
        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

    Returns:
        List[Dict[str, Any]]: A list of search results.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")

    vector_query = VectorizedQuery(vector=query_vector, k_nearest_neighbors=top_k, fields="vector")
    search_results = self.search_client.search(vector_queries=[vector_query], top=top_k)

    results = []
    for result in search_results:
        payload = json.loads(result["payload"])
        # Keep only results whose payload matches every filter key/value pair.
        if filters and any(
            key not in payload or payload[key] != value
            for key, value in filters.items()
        ):
            continue
        result_dict = {
            "id": result["id"],
            "score": result["@search.score"],
            "payload": payload
        }
        results.append(result_dict)
    return results
update_vector(collection_name, vector_id, vector=None, payload=None)

Update a vector's data or payload.

Parameters:

collection_name (str): The name of the collection. Required.
vector_id (str): The unique identifier of the vector to update. Required.
vector (Optional[List[float]]): The new vector data. Default: None.
payload (Optional[Dict[str, Any]]): The new payload data. Default: None.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def update_vector(
    self,
    collection_name: str,
    vector_id: str,
    vector: Optional[List[float]] = None,
    payload: Optional[Dict[str, Any]] = None
) -> None:
    """
    Update a vector's data or payload.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to update.
        vector (Optional[List[float]]): The new vector data.
        payload (Optional[Dict[str, Any]]): The new payload data.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")

    document = {"id": vector_id}
    if vector is not None:
        document["vector"] = vector
    if payload is not None:
        document["payload"] = json.dumps(payload)
    self.search_client.merge_or_upload_documents(documents=[document])
    logger.info(f"Updated vector with ID {vector_id} in collection {collection_name}.")

chroma

chroma_config

ChromaConfig dataclass

Bases: BaseConfig

Configuration for ChromaDB vector database.

Source code in src/aeiva/storage/chroma/chroma_config.py
@dataclass
class ChromaConfig(BaseConfig):
    """
    Configuration for ChromaDB vector database.
    """

    collection_name: str = field(
        default="mem0",
        metadata={"help": "Name of the collection."}
    )
    client: Optional[Any] = field(
        default=None,
        metadata={"help": "Existing ChromaDB client instance (if any)."}
    )
    path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the database directory for local storage."}
    )
    host: Optional[str] = field(
        default=None,
        metadata={"help": "Remote host address for ChromaDB."}
    )
    port: Optional[int] = field(
        default=None,
        metadata={"help": "Remote port for ChromaDB."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Validate that either path or host and port are provided
        if not self.path and not (self.host and self.port):
            raise ValueError("Either 'path' for local storage or both 'host' and 'port' for remote connection must be provided.")

chroma_database

ChromaDatabase

Bases: VectorDatabase

Concrete implementation of VectorStoreBase using ChromaDB.

Source code in src/aeiva/storage/chroma/chroma_database.py
class ChromaDatabase(VectorDatabase):
    """
    Concrete implementation of VectorStoreBase using ChromaDB.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the ChromaDB vector store.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.collection_name = config.get('collection_name')
        self.client = config.get('client')
        self.host = config.get('host')
        self.port = config.get('port')
        self.path = config.get('path')

        if not self.collection_name:
            raise ValueError("Collection name must be provided in the configuration.")

        self.create_client(
            host=self.host,
            port=self.port,
            path=self.path
        )
        self.create_collection(
            collection_name=self.collection_name,
            vector_size=None,  # ChromaDB does not require specifying vector size upfront
            distance_metric='cosine'
        )

    def create_client(
        self,
        uri: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        path: Optional[str] = None,
        **kwargs
    ) -> None:
        """
        Initializes the client connection to the vector store.

        Args:
            uri (Optional[str]): Not used for ChromaDB.
            host (Optional[str]): Host address for ChromaDB server.
            port (Optional[int]): Port for ChromaDB server.
            path (Optional[str]): Path to the database directory.
            **kwargs: Additional parameters.
        """
        if self.client:
            return  # Client already provided

        settings = Settings(anonymized_telemetry=False)

        if host and port:
            settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
            settings.chroma_server_host = host
            settings.chroma_server_http_port = port
        else:
            if not path:
                path = "db"
            settings.persist_directory = path
            settings.is_persistent = True

        self.client = chromadb.Client(settings)
        logger.info("ChromaDB client initialized.")

    def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
        """
        Create a new vector collection in ChromaDB.

        Args:
            collection_name (str): The name of the collection.
            vector_size (int): Not used for ChromaDB.
            distance_metric (str): Not used for ChromaDB.
        """
        # Check if collection exists
        existing_collections = self.list_collections()
        if collection_name in existing_collections:
            logger.info(f"Collection {collection_name} already exists. Skipping creation.")
            self.collection = self.client.get_collection(name=collection_name)
        else:
            self.collection = self.client.create_collection(name=collection_name)
            logger.info(f"Collection {collection_name} created successfully.")

    def insert_vectors(
        self,
        collection_name: str,
        vectors: List[List[float]],
        payloads: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into a collection.

        Args:
            collection_name (str): The name of the collection.
            vectors (List[List[float]]): A list of vectors to insert.
            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
            ids (Optional[List[str]]): Optional unique identifiers for each vector.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        if ids is None:
            ids = [str(i) for i in range(len(vectors))]
        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

        self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
        logger.info(f"Inserted {len(vectors)} vectors into collection {collection_name}.")

    def search_vectors(
        self,
        collection_name: str,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for similar vectors in a collection.

        Args:
            collection_name (str): The name of the collection.
            query_vector (List[float]): The vector to search with.
            top_k (int): The number of top results to return.
            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

        Returns:
            List[Dict[str, Any]]: A list of search results.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        results = self.collection.query(
            query_embeddings=[query_vector],
            where=filters,
            n_results=top_k
        )
        # Parse the results
        output = []
        for idx, (ids, distances, metadatas) in enumerate(zip(results['ids'], results['distances'], results['metadatas'])):
            for i in range(len(ids)):
                result = {
                    'id': ids[i],
                    'score': distances[i],
                    'payload': metadatas[i]
                }
                output.append(result)
        return output

    def delete_vector(self, collection_name: str, vector_id: str) -> None:
        """
        Delete a vector from a collection by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to delete.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        self.collection.delete(ids=[vector_id])
        logger.info(f"Deleted vector with ID {vector_id} from collection {collection_name}.")

    def update_vector(
        self,
        collection_name: str,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Update a vector's data or payload.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to update.
            vector (Optional[List[float]]): The new vector data.
            payload (Optional[Dict[str, Any]]): The new payload data.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        self.collection.update(ids=[vector_id], embeddings=[vector] if vector else None, metadatas=[payload] if payload else None)
        logger.info(f"Updated vector with ID {vector_id} in collection {collection_name}.")

    def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
        """
        Retrieve a vector by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector.

        Returns:
            Dict[str, Any]: A dictionary containing the vector data and payload.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        result = self.collection.get(ids=[vector_id])
        if not result['ids']:
            raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

        vector_data = {
            'id': result['ids'][0],
            'vector': result['embeddings'][0] if result.get('embeddings') else None,
            'payload': result['metadatas'][0]
        }
        return vector_data

    def list_collections(self) -> List[str]:
        """
        List all available vector collections.

        Returns:
            List[str]: A list of collection names.
        """
        collections = self.client.list_collections()
        return [collection.name for collection in collections]

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire vector collection.

        Args:
            collection_name (str): The name of the collection to delete.
        """
        self.client.delete_collection(name=collection_name)
        logger.info(f"Deleted collection {collection_name}.")

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """
        Get information about a collection.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            Dict[str, Any]: Information about the collection.
        """
        collection = self.client.get_collection(name=collection_name)
        return {
            'name': collection.name,
            'metadata': collection.metadata
        }
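
A minimal usage sketch against a local on-disk store (the path, ids, and vectors are illustrative):

db = ChromaDatabase({
    "collection_name": "documents",
    "path": "./chroma_db",   # illustrative local path
})
db.insert_vectors(
    collection_name="documents",
    vectors=[[0.1, 0.2, 0.3]],
    payloads=[{"source": "doc-1"}],
    ids=["1"],
)
hits = db.search_vectors("documents", query_vector=[0.1, 0.2, 0.3], top_k=1)
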
__init__(config)

Initialize the ChromaDB vector store.

Parameters:

config (Dict[str, Any]): Configuration dictionary. Required.
Source code in src/aeiva/storage/chroma/chroma_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the ChromaDB vector store.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.collection_name = config.get('collection_name')
    self.client = config.get('client')
    self.host = config.get('host')
    self.port = config.get('port')
    self.path = config.get('path')

    if not self.collection_name:
        raise ValueError("Collection name must be provided in the configuration.")

    self.create_client(
        host=self.host,
        port=self.port,
        path=self.path
    )
    self.create_collection(
        collection_name=self.collection_name,
        vector_size=None,  # ChromaDB does not require specifying vector size upfront
        distance_metric='cosine'
    )
create_client(uri=None, host=None, port=None, path=None, **kwargs)

Initializes the client connection to the vector store.

Parameters:

uri (Optional[str]): Not used for ChromaDB. Default: None.
host (Optional[str]): Host address for ChromaDB server. Default: None.
port (Optional[int]): Port for ChromaDB server. Default: None.
path (Optional[str]): Path to the database directory. Default: None.
**kwargs: Additional parameters. Default: {}.
Source code in src/aeiva/storage/chroma/chroma_database.py
def create_client(
    self,
    uri: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[int] = None,
    path: Optional[str] = None,
    **kwargs
) -> None:
    """
    Initializes the client connection to the vector store.

    Args:
        uri (Optional[str]): Not used for ChromaDB.
        host (Optional[str]): Host address for ChromaDB server.
        port (Optional[int]): Port for ChromaDB server.
        path (Optional[str]): Path to the database directory.
        **kwargs: Additional parameters.
    """
    if self.client:
        return  # Client already provided

    settings = Settings(anonymized_telemetry=False)

    if host and port:
        settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
        settings.chroma_server_host = host
        settings.chroma_server_http_port = port
    else:
        if not path:
            path = "db"
        settings.persist_directory = path
        settings.is_persistent = True

    self.client = chromadb.Client(settings)
    logger.info("ChromaDB client initialized.")
create_collection(collection_name, vector_size, distance_metric)

Create a new vector collection in ChromaDB.

Parameters:

collection_name (str): The name of the collection. Required.
vector_size (int): Not used for ChromaDB. Required.
distance_metric (str): Not used for ChromaDB. Required.
Source code in src/aeiva/storage/chroma/chroma_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
    """
    Create a new vector collection in ChromaDB.

    Args:
        collection_name (str): The name of the collection.
        vector_size (int): Not used for ChromaDB.
        distance_metric (str): Not used for ChromaDB.
    """
    # Check if collection exists
    existing_collections = self.list_collections()
    if collection_name in existing_collections:
        logger.info(f"Collection {collection_name} already exists. Skipping creation.")
        self.collection = self.client.get_collection(name=collection_name)
    else:
        self.collection = self.client.create_collection(name=collection_name)
        logger.info(f"Collection {collection_name} created successfully.")
delete_collection(collection_name)

Delete an entire vector collection.

Parameters:

collection_name (str): The name of the collection to delete. Required.
Source code in src/aeiva/storage/chroma/chroma_database.py
def delete_collection(self, collection_name: str) -> None:
    """
    Delete an entire vector collection.

    Args:
        collection_name (str): The name of the collection to delete.
    """
    self.client.delete_collection(name=collection_name)
    logger.info(f"Deleted collection {collection_name}.")
delete_vector(collection_name, vector_id)

Delete a vector from a collection by its ID.

Parameters:

collection_name (str): The name of the collection. Required.
vector_id (str): The unique identifier of the vector to delete. Required.
Source code in src/aeiva/storage/chroma/chroma_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:
    """
    Delete a vector from a collection by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to delete.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    self.collection.delete(ids=[vector_id])
    logger.info(f"Deleted vector with ID {vector_id} from collection {collection_name}.")
get_collection_info(collection_name)

Get information about a collection.

Parameters:

collection_name (str): The name of the collection. Required.

Returns:

Dict[str, Any]: Information about the collection.

Source code in src/aeiva/storage/chroma/chroma_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """
    Get information about a collection.

    Args:
        collection_name (str): The name of the collection.

    Returns:
        Dict[str, Any]: Information about the collection.
    """
    collection = self.client.get_collection(name=collection_name)
    return {
        'name': collection.name,
        'metadata': collection.metadata
    }
get_vector(collection_name, vector_id)

Retrieve a vector by its ID.

Parameters:

collection_name (str): The name of the collection. Required.
vector_id (str): The unique identifier of the vector. Required.

Returns:

Dict[str, Any]: A dictionary containing the vector data and payload.

Source code in src/aeiva/storage/chroma/chroma_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
    """
    Retrieve a vector by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector.

    Returns:
        Dict[str, Any]: A dictionary containing the vector data and payload.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    result = self.collection.get(ids=[vector_id])
    if not result['ids']:
        raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

    vector_data = {
        'id': result['ids'][0],
        'vector': result['embeddings'][0] if result.get('embeddings') else None,
        'payload': result['metadatas'][0]
    }
    return vector_data
insert_vectors(collection_name, vectors, payloads=None, ids=None)

Insert vectors into a collection.

Parameters:

collection_name (str): The name of the collection. Required.
vectors (List[List[float]]): A list of vectors to insert. Required.
payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector. Default: None.
ids (Optional[List[str]]): Optional unique identifiers for each vector. Default: None.
Source code in src/aeiva/storage/chroma/chroma_database.py
def insert_vectors(
    self,
    collection_name: str,
    vectors: List[List[float]],
    payloads: Optional[List[Dict[str, Any]]] = None,
    ids: Optional[List[str]] = None
) -> None:
    """
    Insert vectors into a collection.

    Args:
        collection_name (str): The name of the collection.
        vectors (List[List[float]]): A list of vectors to insert.
        payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
        ids (Optional[List[str]]): Optional unique identifiers for each vector.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    if ids is None:
        ids = [str(i) for i in range(len(vectors))]
    if payloads is None:
        payloads = [{} for _ in range(len(vectors))]
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

    self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
    logger.info(f"Inserted {len(vectors)} vectors into collection {collection_name}.")
list_collections()

List all available vector collections.

Returns:

List[str]: A list of collection names.

Source code in src/aeiva/storage/chroma/chroma_database.py
def list_collections(self) -> List[str]:
    """
    List all available vector collections.

    Returns:
        List[str]: A list of collection names.
    """
    collections = self.client.list_collections()
    return [collection.name for collection in collections]
search_vectors(collection_name, query_vector, top_k=5, filters=None)

Search for similar vectors in a collection.

Parameters:

collection_name (str): The name of the collection. Required.
query_vector (List[float]): The vector to search with. Required.
top_k (int): The number of top results to return. Default: 5.
filters (Optional[Dict[str, Any]]): Optional filters to apply to the search. Default: None.

Returns:

List[Dict[str, Any]]: A list of search results.

Source code in src/aeiva/storage/chroma/chroma_database.py
def search_vectors(
    self,
    collection_name: str,
    query_vector: List[float],
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    Search for similar vectors in a collection.

    Args:
        collection_name (str): The name of the collection.
        query_vector (List[float]): The vector to search with.
        top_k (int): The number of top results to return.
        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

    Returns:
        List[Dict[str, Any]]: A list of search results.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    results = self.collection.query(
        query_embeddings=[query_vector],
        where=filters,
        n_results=top_k
    )
    # Parse the results
    output = []
    for idx, (ids, distances, metadatas) in enumerate(zip(results['ids'], results['distances'], results['metadatas'])):
        for i in range(len(ids)):
            result = {
                'id': ids[i],
                'score': distances[i],
                'payload': metadatas[i]
            }
            output.append(result)
    return output
update_vector(collection_name, vector_id, vector=None, payload=None)

Update a vector's data or payload.

Parameters:

collection_name (str): The name of the collection. Required.
vector_id (str): The unique identifier of the vector to update. Required.
vector (Optional[List[float]]): The new vector data. Default: None.
payload (Optional[Dict[str, Any]]): The new payload data. Default: None.
Source code in src/aeiva/storage/chroma/chroma_database.py
def update_vector(
    self,
    collection_name: str,
    vector_id: str,
    vector: Optional[List[float]] = None,
    payload: Optional[Dict[str, Any]] = None
) -> None:
    """
    Update a vector's data or payload.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to update.
        vector (Optional[List[float]]): The new vector data.
        payload (Optional[Dict[str, Any]]): The new payload data.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    self.collection.update(ids=[vector_id], embeddings=[vector] if vector else None, metadatas=[payload] if payload else None)
    logger.info(f"Updated vector with ID {vector_id} in collection {collection_name}.")

database_factory

DatabaseConfigFactory

Factory class to create database configuration objects based on the provider name.

Example

config = DatabaseConfigFactory.create(
    'milvus',
    host='localhost',
    port=19530,
    embedding_model_dims=128,
    ...
)

Source code in src/aeiva/storage/database_factory.py
class DatabaseConfigFactory:
    """
    Factory class to create database configuration objects based on the provider name.

    Example:
        config = DatabaseConfigFactory.create(
            'milvus',
            host='localhost',
            port=19530,
            embedding_model_dims=128,
            ...
        )
    """

    provider_to_class = {
        "milvus": "aeiva.storage.milvus.milvus_config.MilvusConfig",
        "chroma": "aeiva.storage.chroma.chroma_config.ChromaConfig",
        "azure_ai_search": "aeiva.storage.azure_ai_search.azure_ai_search_config.AzureAISearchConfig",
        "pgvector": "aeiva.storage.pgvector.pgvector_config.PGVectorConfig",
        "qdrant": "aeiva.storage.qdrant.qdrant_config.QdrantConfig",
        "neo4j": "aeiva.storage.neo4jdb.neo4j_config.Neo4jConfig",
        "sqlite": "aeiva.storage.sqlite.sqlite_config.SQLiteConfig",
        "postgresql": "aeiva.storage.postgresql.postgresql_config.PostgreSQLConfig",
        "weaviate": "aeiva.storage.weaviate.weaviate_config.WeaviateConfig",
    }

    @classmethod
    def create(cls, provider_name: str, **kwargs) -> Any:
        """
        Create a database configuration object based on the provider name.

        Args:
            provider_name (str): The name of the database provider (e.g., 'milvus', 'chroma').
            **kwargs: Configuration parameters specific to the database provider.

        Returns:
            Any: An instance of the database configuration class.

        Raises:
            ValueError: If the provider name is not supported.
            ImportError: If the configuration class cannot be imported.
        """
        class_path = cls.provider_to_class.get(provider_name.lower())
        if class_path:
            config_class = load_class(class_path)
            return config_class(**kwargs)
        else:
            raise ValueError(f"Unsupported database provider: {provider_name}")
create(provider_name, **kwargs) classmethod

Create a database configuration object based on the provider name.

Parameters:

provider_name (str): The name of the database provider (e.g., 'milvus', 'chroma'). Required.
**kwargs: Configuration parameters specific to the database provider. Default: {}.

Returns:

Any: An instance of the database configuration class.

Raises:

ValueError: If the provider name is not supported.
ImportError: If the configuration class cannot be imported.

Source code in src/aeiva/storage/database_factory.py
@classmethod
def create(cls, provider_name: str, **kwargs) -> Any:
    """
    Create a database configuration object based on the provider name.

    Args:
        provider_name (str): The name of the database provider (e.g., 'milvus', 'chroma').
        **kwargs: Configuration parameters specific to the database provider.

    Returns:
        Any: An instance of the database configuration class.

    Raises:
        ValueError: If the provider name is not supported.
        ImportError: If the configuration class cannot be imported.
    """
    class_path = cls.provider_to_class.get(provider_name.lower())
    if class_path:
        config_class = load_class(class_path)
        return config_class(**kwargs)
    else:
        raise ValueError(f"Unsupported database provider: {provider_name}")

DatabaseFactory

Factory class to create database objects based on the provider name and configuration.

Example

db = DatabaseFactory.create('milvus', config)

Source code in src/aeiva/storage/database_factory.py
class DatabaseFactory:
    """
    Factory class to create database objects based on the provider name and configuration.

    Example:
        db = DatabaseFactory.create('milvus', config)
    """

    provider_to_class = {
        "milvus": "aeiva.storage.milvus.milvus_database.MilvusDatabase",
        "chroma": "aeiva.storage.chroma.chroma_database.ChromaDatabase",
        "azure_ai_search": "aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase",
        "pgvector": "aeiva.storage.pgvector.pgvector_database.PGVectorDatabase",
        "qdrant": "aeiva.storage.qdrant.qdrant_database.QdrantDatabase",
        "neo4j": "aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase",
        "sqlite": "aeiva.storage.sqlite.sqlite_database.SQLiteDatabase",
        "postgresql": "aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase",
        "weaviate": "aeiva.storage.weaviate.weaviate_database.WeaviateDatabase",
    }

    @classmethod
    def create(cls, provider_name: str, config: Any) -> Any:
        """
        Create a database object based on the provider name and configuration.

        Args:
            provider_name (str): The name of the database provider.
            config (Any): Configuration object or dictionary for the database.

        Returns:
            Any: An instance of the database class.

        Raises:
            ValueError: If the provider name is not supported.
            ImportError: If the database class cannot be imported.
            TypeError: If the configuration cannot be converted to a dictionary.
        """
        class_path = cls.provider_to_class.get(provider_name.lower())
        if class_path:
            db_class = load_class(class_path)
            if isinstance(config, dict):
                return db_class(config)
            elif hasattr(config, 'to_dict'):
                # Assuming config is a dataclass with a 'to_dict' method
                return db_class(config.to_dict())
            elif hasattr(config, '__dict__'):
                # If config is a dataclass without 'to_dict', use __dict__
                return db_class(config.__dict__)
            else:
                raise TypeError(
                    "Config must be a dict or an object with 'to_dict' or '__dict__' method."
                )
        else:
            raise ValueError(f"Unsupported database provider: {provider_name}")
create(provider_name, config) classmethod

Create a database object based on the provider name and configuration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| provider_name | str | The name of the database provider. | required |
| config | Any | Configuration object or dictionary for the database. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Any | Any | An instance of the database class. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the provider name is not supported. |
| ImportError | If the database class cannot be imported. |
| TypeError | If the configuration cannot be converted to a dictionary. |

Source code in src/aeiva/storage/database_factory.py
@classmethod
def create(cls, provider_name: str, config: Any) -> Any:
    """
    Create a database object based on the provider name and configuration.

    Args:
        provider_name (str): The name of the database provider.
        config (Any): Configuration object or dictionary for the database.

    Returns:
        Any: An instance of the database class.

    Raises:
        ValueError: If the provider name is not supported.
        ImportError: If the database class cannot be imported.
        TypeError: If the configuration cannot be converted to a dictionary.
    """
    class_path = cls.provider_to_class.get(provider_name.lower())
    if class_path:
        db_class = load_class(class_path)
        if isinstance(config, dict):
            return db_class(config)
        elif hasattr(config, 'to_dict'):
            # Assuming config is a dataclass with a 'to_dict' method
            return db_class(config.to_dict())
        elif hasattr(config, '__dict__'):
            # If config is a dataclass without 'to_dict', use __dict__
            return db_class(config.__dict__)
        else:
            raise TypeError(
                "Config must be a dict or an object with 'to_dict' or '__dict__' method."
            )
    else:
        raise ValueError(f"Unsupported database provider: {provider_name}")

load_class(class_path)

Dynamically load a class from a string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| class_path | str | The full path to the class, e.g., 'module.submodule.ClassName'. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Type | Type | The class type. |

Raises:

| Type | Description |
| --- | --- |
| ImportError | If the module or class cannot be found. |

Source code in src/aeiva/storage/database_factory.py
def load_class(class_path: str) -> Type:
    """
    Dynamically load a class from a string.

    Args:
        class_path (str): The full path to the class, e.g., 'module.submodule.ClassName'.

    Returns:
        Type: The class type.

    Raises:
        ImportError: If the module or class cannot be found.
    """
    try:
        module_path, class_name = class_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(f"Cannot import '{class_name}' from '{module_path}': {e}")

graph_database

GraphDatabase

Bases: ABC

Abstract base class for graph database operations.

Source code in src/aeiva/storage/graph_database.py
class GraphDatabase(ABC):
    """
    Abstract base class for graph database operations.
    """

    @abstractmethod
    def add_node(
        self, 
        node_id: str, 
        properties: Optional[Dict[str, Any]] = None, 
        labels: Optional[List[str]] = None
    ) -> None:
        """
        Adds a node to the graph.

        Args:
            node_id (str): Unique identifier for the node.
            properties (Optional[Dict[str, Any]]): Properties associated with the node.
            labels (Optional[List[str]]): Labels or types associated with the node.

        Raises:
            StorageError: If there is an issue adding the node.
        """
        pass

    @abstractmethod
    def add_edge(
        self, 
        source_id: str, 
        target_id: str, 
        relationship: str, 
        properties: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Adds an edge (relationship) between two nodes.

        Args:
            source_id (str): Unique identifier of the source node.
            target_id (str): Unique identifier of the target node.
            relationship (str): Type of the relationship.
            properties (Optional[Dict[str, Any]]): Properties associated with the edge.

        Raises:
            NodeNotFoundError: If either the source or target node does not exist.
            StorageError: If there is an issue adding the edge.
        """
        pass

    @abstractmethod
    def get_node(self, node_id: str) -> Dict[str, Any]:
        """
        Retrieves a node by its identifier.

        Args:
            node_id (str): Unique identifier of the node.

        Returns:
            Dict[str, Any]: A dictionary containing the node's properties and labels.

        Raises:
            NodeNotFoundError: If the node does not exist.
            StorageError: If there is an issue retrieving the node.
        """
        pass

    @abstractmethod
    def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:
        """
        Updates properties of a node.

        Args:
            node_id (str): Unique identifier of the node.
            properties (Dict[str, Any]): Properties to update.

        Raises:
            NodeNotFoundError: If the node does not exist.
            StorageError: If there is an issue updating the node.
        """
        pass

    @abstractmethod
    def delete_node(self, node_id: str) -> None:
        """
        Deletes a node from the graph.

        Args:
            node_id (str): Unique identifier of the node.

        Raises:
            NodeNotFoundError: If the node does not exist.
            StorageError: If there is an issue deleting the node.
        """
        pass

    @abstractmethod
    def delete_all(self) -> None:
        """
        Deletes all nodes and their associated relationships from the graph.

        Raises:
            StorageError: If there is an issue deleting all nodes and relationships.
        """
        pass

    @abstractmethod
    def delete_all_edges(self) -> None:
        """
        Deletes all edges from the graph without deleting the nodes.

        Raises:
            StorageError: If there is an issue deleting all relationships.
        """
        pass

    @abstractmethod
    def delete_relationships_by_type(self, relationship: str) -> None:
        """
        Deletes all relationships of a specific type from the graph.

        Args:
            relationship (str): The type of relationships to delete.

        Raises:
            StorageError: If there is an issue deleting the relationships.
        """
        pass

    @abstractmethod
    def delete_edge(
        self, 
        source_id: str, 
        target_id: str, 
        relationship: str
    ) -> None:
        """
        Deletes a specific relationship between two nodes.

        Args:
            source_id (str): Unique identifier of the source node.
            target_id (str): Unique identifier of the target node.
            relationship (str): Type of the relationship to delete.

        Raises:
            RelationshipNotFoundError: If the relationship does not exist.
            StorageError: If there is an issue deleting the relationship.
        """
        pass

    @abstractmethod
    def update_edge(
        self, 
        source_id: str, 
        target_id: str, 
        relationship: str, 
        properties: Dict[str, Any]
    ) -> None:
        """
        Updates properties of a specific relationship between two nodes.

        Args:
            source_id (str): Unique identifier of the source node.
            target_id (str): Unique identifier of the target node.
            relationship (str): Type of the relationship to update.
            properties (Dict[str, Any]): Properties to update on the relationship.

        Raises:
            RelationshipNotFoundError: If the relationship does not exist.
            StorageError: If there is an issue updating the relationship.
        """
        pass

    @abstractmethod
    def get_relationship(
        self, 
        source_id: str, 
        target_id: str, 
        relationship: str
    ) -> Dict[str, Any]:
        """
        Retrieves a specific relationship between two nodes.

        Args:
            source_id (str): Unique identifier of the source node.
            target_id (str): Unique identifier of the target node.
            relationship (str): Type of the relationship to retrieve.

        Returns:
            Dict[str, Any]: A dictionary containing the relationship's properties.

        Raises:
            RelationshipNotFoundError: If the relationship does not exist.
            StorageError: If there is an issue retrieving the relationship.
        """
        pass

    @abstractmethod
    def get_neighbors(
        self, 
        node_id: str, 
        relationship: Optional[str] = None, 
        direction: str = "both"
    ) -> List[Dict[str, Any]]:
        """
        Retrieves neighboring nodes connected by edges.

        Args:
            node_id (str): Unique identifier of the node.
            relationship (Optional[str]): Filter by relationship type.
            direction (str): Direction of the relationships ('in', 'out', 'both').

        Returns:
            List[Dict[str, Any]]: A list of neighboring nodes.

        Raises:
            NodeNotFoundError: If the node does not exist.
            StorageError: If there is an issue retrieving neighbors.
        """
        pass

    @abstractmethod
    def query_nodes(
        self, 
        properties: Dict[str, Any], 
        labels: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """
        Queries nodes based on properties and labels.

        Args:
            properties (Dict[str, Any]): Properties to filter nodes.
            labels (Optional[List[str]]): Labels to filter nodes.

        Returns:
            List[Dict[str, Any]]: A list of nodes matching the query.

        Raises:
            StorageError: If there is an issue querying nodes.
        """
        pass

    @abstractmethod
    def execute_query(
        self, 
        query: str, 
        parameters: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Executes a raw query against the graph database.

        Args:
            query (str): The query string.
            parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.

        Returns:
            Any: The result of the query.

        Raises:
            StorageError: If there is an issue executing the query.
        """
        pass

    @abstractmethod
    def close(self) -> None:
        """
        Closes the graph database connection and releases resources.
        """
        pass
add_edge(source_id, target_id, relationship, properties=None) abstractmethod

Adds an edge (relationship) between two nodes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source_id | str | Unique identifier of the source node. | required |
| target_id | str | Unique identifier of the target node. | required |
| relationship | str | Type of the relationship. | required |
| properties | Optional[Dict[str, Any]] | Properties associated with the edge. | None |

Raises:

| Type | Description |
| --- | --- |
| NodeNotFoundError | If either the source or target node does not exist. |
| StorageError | If there is an issue adding the edge. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def add_edge(
    self, 
    source_id: str, 
    target_id: str, 
    relationship: str, 
    properties: Optional[Dict[str, Any]] = None
) -> None:
    """
    Adds an edge (relationship) between two nodes.

    Args:
        source_id (str): Unique identifier of the source node.
        target_id (str): Unique identifier of the target node.
        relationship (str): Type of the relationship.
        properties (Optional[Dict[str, Any]]): Properties associated with the edge.

    Raises:
        NodeNotFoundError: If either the source or target node does not exist.
        StorageError: If there is an issue adding the edge.
    """
    pass
add_node(node_id, properties=None, labels=None) abstractmethod

Adds a node to the graph.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_id | str | Unique identifier for the node. | required |
| properties | Optional[Dict[str, Any]] | Properties associated with the node. | None |
| labels | Optional[List[str]] | Labels or types associated with the node. | None |

Raises:

| Type | Description |
| --- | --- |
| StorageError | If there is an issue adding the node. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def add_node(
    self, 
    node_id: str, 
    properties: Optional[Dict[str, Any]] = None, 
    labels: Optional[List[str]] = None
) -> None:
    """
    Adds a node to the graph.

    Args:
        node_id (str): Unique identifier for the node.
        properties (Optional[Dict[str, Any]]): Properties associated with the node.
        labels (Optional[List[str]]): Labels or types associated with the node.

    Raises:
        StorageError: If there is an issue adding the node.
    """
    pass
close() abstractmethod

Closes the graph database connection and releases resources.

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def close(self) -> None:
    """
    Closes the graph database connection and releases resources.
    """
    pass
delete_all() abstractmethod

Deletes all nodes and their associated relationships from the graph.

Raises:

| Type | Description |
| --- | --- |
| StorageError | If there is an issue deleting all nodes and relationships. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def delete_all(self) -> None:
    """
    Deletes all nodes and their associated relationships from the graph.

    Raises:
        StorageError: If there is an issue deleting all nodes and relationships.
    """
    pass
delete_all_edges() abstractmethod

Deletes all edges from the graph without deleting the nodes.

Raises:

| Type | Description |
| --- | --- |
| StorageError | If there is an issue deleting all relationships. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def delete_all_edges(self) -> None:
    """
    Deletes all edges from the graph without deleting the nodes.

    Raises:
        StorageError: If there is an issue deleting all relationships.
    """
    pass
delete_edge(source_id, target_id, relationship) abstractmethod

Deletes a specific relationship between two nodes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source_id | str | Unique identifier of the source node. | required |
| target_id | str | Unique identifier of the target node. | required |
| relationship | str | Type of the relationship to delete. | required |

Raises:

| Type | Description |
| --- | --- |
| RelationshipNotFoundError | If the relationship does not exist. |
| StorageError | If there is an issue deleting the relationship. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def delete_edge(
    self, 
    source_id: str, 
    target_id: str, 
    relationship: str
) -> None:
    """
    Deletes a specific relationship between two nodes.

    Args:
        source_id (str): Unique identifier of the source node.
        target_id (str): Unique identifier of the target node.
        relationship (str): Type of the relationship to delete.

    Raises:
        RelationshipNotFoundError: If the relationship does not exist.
        StorageError: If there is an issue deleting the relationship.
    """
    pass
delete_node(node_id) abstractmethod

Deletes a node from the graph.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_id | str | Unique identifier of the node. | required |

Raises:

| Type | Description |
| --- | --- |
| NodeNotFoundError | If the node does not exist. |
| StorageError | If there is an issue deleting the node. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def delete_node(self, node_id: str) -> None:
    """
    Deletes a node from the graph.

    Args:
        node_id (str): Unique identifier of the node.

    Raises:
        NodeNotFoundError: If the node does not exist.
        StorageError: If there is an issue deleting the node.
    """
    pass
delete_relationships_by_type(relationship) abstractmethod

Deletes all relationships of a specific type from the graph.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| relationship | str | The type of relationships to delete. | required |

Raises:

| Type | Description |
| --- | --- |
| StorageError | If there is an issue deleting the relationships. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def delete_relationships_by_type(self, relationship: str) -> None:
    """
    Deletes all relationships of a specific type from the graph.

    Args:
        relationship (str): The type of relationships to delete.

    Raises:
        StorageError: If there is an issue deleting the relationships.
    """
    pass
execute_query(query, parameters=None) abstractmethod

Executes a raw query against the graph database.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | str | The query string. | required |
| parameters | Optional[Dict[str, Any]] | Parameters for parameterized queries. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Any | Any | The result of the query. |

Raises:

| Type | Description |
| --- | --- |
| StorageError | If there is an issue executing the query. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def execute_query(
    self, 
    query: str, 
    parameters: Optional[Dict[str, Any]] = None
) -> Any:
    """
    Executes a raw query against the graph database.

    Args:
        query (str): The query string.
        parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.

    Returns:
        Any: The result of the query.

    Raises:
        StorageError: If there is an issue executing the query.
    """
    pass
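
A hedged sketch of a parameterized call: the query language is backend-specific (Cypher for the Neo4j implementation documented later on this page), and db stands for any concrete GraphDatabase instance.

# 'db' is a concrete GraphDatabase; the Cypher text assumes a Neo4j backend.
rows = db.execute_query(
    "MATCH (n) WHERE n.name = $name RETURN n LIMIT $k",
    parameters={"name": "alice", "k": 10},
)
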
get_neighbors(node_id, relationship=None, direction='both') abstractmethod

Retrieves neighboring nodes connected by edges.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_id | str | Unique identifier of the node. | required |
| relationship | Optional[str] | Filter by relationship type. | None |
| direction | str | Direction of the relationships ('in', 'out', 'both'). | 'both' |

Returns:

| Type | Description |
| --- | --- |
| List[Dict[str, Any]] | A list of neighboring nodes. |

Raises:

| Type | Description |
| --- | --- |
| NodeNotFoundError | If the node does not exist. |
| StorageError | If there is an issue retrieving neighbors. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def get_neighbors(
    self, 
    node_id: str, 
    relationship: Optional[str] = None, 
    direction: str = "both"
) -> List[Dict[str, Any]]:
    """
    Retrieves neighboring nodes connected by edges.

    Args:
        node_id (str): Unique identifier of the node.
        relationship (Optional[str]): Filter by relationship type.
        direction (str): Direction of the relationships ('in', 'out', 'both').

    Returns:
        List[Dict[str, Any]]: A list of neighboring nodes.

    Raises:
        NodeNotFoundError: If the node does not exist.
        StorageError: If there is an issue retrieving neighbors.
    """
    pass
get_node(node_id) abstractmethod

Retrieves a node by its identifier.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_id | str | Unique identifier of the node. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Any] | A dictionary containing the node's properties and labels. |

Raises:

| Type | Description |
| --- | --- |
| NodeNotFoundError | If the node does not exist. |
| StorageError | If there is an issue retrieving the node. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def get_node(self, node_id: str) -> Dict[str, Any]:
    """
    Retrieves a node by its identifier.

    Args:
        node_id (str): Unique identifier of the node.

    Returns:
        Dict[str, Any]: A dictionary containing the node's properties and labels.

    Raises:
        NodeNotFoundError: If the node does not exist.
        StorageError: If there is an issue retrieving the node.
    """
    pass
get_relationship(source_id, target_id, relationship) abstractmethod

Retrieves a specific relationship between two nodes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source_id | str | Unique identifier of the source node. | required |
| target_id | str | Unique identifier of the target node. | required |
| relationship | str | Type of the relationship to retrieve. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Any] | A dictionary containing the relationship's properties. |

Raises:

| Type | Description |
| --- | --- |
| RelationshipNotFoundError | If the relationship does not exist. |
| StorageError | If there is an issue retrieving the relationship. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def get_relationship(
    self, 
    source_id: str, 
    target_id: str, 
    relationship: str
) -> Dict[str, Any]:
    """
    Retrieves a specific relationship between two nodes.

    Args:
        source_id (str): Unique identifier of the source node.
        target_id (str): Unique identifier of the target node.
        relationship (str): Type of the relationship to retrieve.

    Returns:
        Dict[str, Any]: A dictionary containing the relationship's properties.

    Raises:
        RelationshipNotFoundError: If the relationship does not exist.
        StorageError: If there is an issue retrieving the relationship.
    """
    pass
query_nodes(properties, labels=None) abstractmethod

Queries nodes based on properties and labels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| properties | Dict[str, Any] | Properties to filter nodes. | required |
| labels | Optional[List[str]] | Labels to filter nodes. | None |

Returns:

| Type | Description |
| --- | --- |
| List[Dict[str, Any]] | A list of nodes matching the query. |

Raises:

| Type | Description |
| --- | --- |
| StorageError | If there is an issue querying nodes. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def query_nodes(
    self, 
    properties: Dict[str, Any], 
    labels: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
    """
    Queries nodes based on properties and labels.

    Args:
        properties (Dict[str, Any]): Properties to filter nodes.
        labels (Optional[List[str]]): Labels to filter nodes.

    Returns:
        List[Dict[str, Any]]: A list of nodes matching the query.

    Raises:
        StorageError: If there is an issue querying nodes.
    """
    pass
update_edge(source_id, target_id, relationship, properties) abstractmethod

Updates properties of a specific relationship between two nodes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source_id | str | Unique identifier of the source node. | required |
| target_id | str | Unique identifier of the target node. | required |
| relationship | str | Type of the relationship to update. | required |
| properties | Dict[str, Any] | Properties to update on the relationship. | required |

Raises:

| Type | Description |
| --- | --- |
| RelationshipNotFoundError | If the relationship does not exist. |
| StorageError | If there is an issue updating the relationship. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def update_edge(
    self, 
    source_id: str, 
    target_id: str, 
    relationship: str, 
    properties: Dict[str, Any]
) -> None:
    """
    Updates properties of a specific relationship between two nodes.

    Args:
        source_id (str): Unique identifier of the source node.
        target_id (str): Unique identifier of the target node.
        relationship (str): Type of the relationship to update.
        properties (Dict[str, Any]): Properties to update on the relationship.

    Raises:
        RelationshipNotFoundError: If the relationship does not exist.
        StorageError: If there is an issue updating the relationship.
    """
    pass
update_node(node_id, properties) abstractmethod

Updates properties of a node.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| node_id | str | Unique identifier of the node. | required |
| properties | Dict[str, Any] | Properties to update. | required |

Raises:

| Type | Description |
| --- | --- |
| NodeNotFoundError | If the node does not exist. |
| StorageError | If there is an issue updating the node. |

Source code in src/aeiva/storage/graph_database.py
@abstractmethod
def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:
    """
    Updates properties of a node.

    Args:
        node_id (str): Unique identifier of the node.
        properties (Dict[str, Any]): Properties to update.

    Raises:
        NodeNotFoundError: If the node does not exist.
        StorageError: If there is an issue updating the node.
    """
    pass

NodeNotFoundError

Bases: Exception

Exception raised when a node is not found in the graph database.

Source code in src/aeiva/storage/graph_database.py
class NodeNotFoundError(Exception):
    """Exception raised when a node is not found in the graph database."""
    pass

RelationshipNotFoundError

Bases: Exception

Exception raised when a relationship is not found in the graph database.

Source code in src/aeiva/storage/graph_database.py
class RelationshipNotFoundError(Exception):
    """Exception raised when a relationship is not found in the graph database."""
    pass

StorageError

Bases: Exception

Exception raised when there is a storage-related error in the graph database.

Source code in src/aeiva/storage/graph_database.py
class StorageError(Exception):
    """Exception raised when there is a storage-related error in the graph database."""
    pass
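
To make the contract above concrete, here is a toy, purely illustrative in-memory subclass (not part of Aeiva) that satisfies every abstract method of GraphDatabase and raises the exception types defined in this module; leaving any method unimplemented would make instantiation fail with TypeError.

from typing import Any, Dict, List, Tuple

from aeiva.storage.graph_database import (
    GraphDatabase, NodeNotFoundError, RelationshipNotFoundError, StorageError,
)

class InMemoryGraph(GraphDatabase):
    """Toy backend: node dict plus (src, dst, rel) -> properties edge dict."""

    def __init__(self) -> None:
        self._nodes: Dict[str, Dict[str, Any]] = {}
        self._edges: Dict[Tuple[str, str, str], Dict[str, Any]] = {}

    def add_node(self, node_id, properties=None, labels=None) -> None:
        self._nodes[node_id] = {"properties": properties or {}, "labels": labels or []}

    def add_edge(self, source_id, target_id, relationship, properties=None) -> None:
        if source_id not in self._nodes or target_id not in self._nodes:
            raise NodeNotFoundError(f"unknown node: {source_id} or {target_id}")
        self._edges[(source_id, target_id, relationship)] = properties or {}

    def get_node(self, node_id) -> Dict[str, Any]:
        if node_id not in self._nodes:
            raise NodeNotFoundError(node_id)
        return self._nodes[node_id]

    def update_node(self, node_id, properties) -> None:
        self.get_node(node_id)["properties"].update(properties)

    def delete_node(self, node_id) -> None:
        self.get_node(node_id)  # NodeNotFoundError if absent
        del self._nodes[node_id]
        self._edges = {k: v for k, v in self._edges.items() if node_id not in k[:2]}

    def delete_all(self) -> None:
        self._nodes.clear()
        self._edges.clear()

    def delete_all_edges(self) -> None:
        self._edges.clear()

    def delete_relationships_by_type(self, relationship) -> None:
        self._edges = {k: v for k, v in self._edges.items() if k[2] != relationship}

    def delete_edge(self, source_id, target_id, relationship) -> None:
        if (source_id, target_id, relationship) not in self._edges:
            raise RelationshipNotFoundError((source_id, target_id, relationship))
        del self._edges[(source_id, target_id, relationship)]

    def update_edge(self, source_id, target_id, relationship, properties) -> None:
        self.get_relationship(source_id, target_id, relationship).update(properties)

    def get_relationship(self, source_id, target_id, relationship) -> Dict[str, Any]:
        if (source_id, target_id, relationship) not in self._edges:
            raise RelationshipNotFoundError((source_id, target_id, relationship))
        return self._edges[(source_id, target_id, relationship)]

    def get_neighbors(self, node_id, relationship=None, direction="both") -> List[Dict[str, Any]]:
        self.get_node(node_id)  # validate existence first
        neighbors = []
        for (src, dst, rel) in self._edges:
            if relationship is not None and rel != relationship:
                continue
            if direction in ("out", "both") and src == node_id:
                neighbors.append(self.get_node(dst))
            if direction in ("in", "both") and dst == node_id:
                neighbors.append(self.get_node(src))
        return neighbors

    def query_nodes(self, properties, labels=None) -> List[Dict[str, Any]]:
        return [
            node for node in self._nodes.values()
            if all(node["properties"].get(k) == v for k, v in properties.items())
            and (labels is None or set(labels) <= set(node["labels"]))
        ]

    def execute_query(self, query, parameters=None) -> Any:
        raise StorageError("this toy backend has no query language")

    def close(self) -> None:
        pass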

milvus

milvus_config

MilvusConfig dataclass

Bases: BaseConfig

Configuration for Milvus vector database.

Source code in src/aeiva/storage/milvus/milvus_config.py
@dataclass
class MilvusConfig(BaseConfig):
    """
    Configuration for Milvus vector database.
    """

    uri: str = field(
        default="http://localhost:19530",
        metadata={"help": "Full URL for Milvus server."}
    )
    token: Optional[str] = field(
        default=None,
        metadata={"help": "Token for Milvus server authentication (if required)."}
    )
    collection_name: str = field(
        default="mem0",
        metadata={"help": "Name of the collection."}
    )
    embedding_model_dims: int = field(
        default=1536,
        metadata={"help": "Dimensions of the embedding model."}
    )
    metric_type: str = field(
        default="L2",
        metadata={"help": "Metric type for similarity search (e.g., 'L2', 'IP', 'COSINE')."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Validate metric_type
        valid_metrics = {"L2", "IP", "COSINE", "HAMMING", "JACCARD"}
        if self.metric_type not in valid_metrics:
            raise ValueError(f"Invalid metric_type '{self.metric_type}'. Valid options are {valid_metrics}.")

milvus_database

MilvusDatabase

Bases: VectorDatabase

Concrete implementation of VectorDatabase using Milvus.

Source code in src/aeiva/storage/milvus/milvus_database.py
class MilvusDatabase(VectorDatabase):
    """
    Concrete implementation of VectorDatabase using Milvus.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the Milvus vector store.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.collection_name = config.get('collection_name')
        self.uri = config.get('uri')
        self.user = config.get('user')
        self.password = config.get('password')
        self.token = config.get('token')
        self.embedding_model_dims = config.get('embedding_model_dims')
        self.metric_type = config.get('metric_type', 'L2')  # Default to 'L2' metric

        if not all([self.collection_name, self.uri, self.embedding_model_dims]):
            raise ValueError("Required configuration parameters are missing.")

        self.create_client(
            uri=self.uri,
            user=self.user,
            password=self.password,
            token=self.token
        )
        self.create_collection(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            distance_metric=self.metric_type
        )

    def create_client(
        self,
        uri: str,
        user: Optional[str] = None,
        password: Optional[str] = None,
        token: Optional[str] = None,
        **kwargs
    ) -> None:
        """
        Initializes the client connection to the Milvus vector store.

        Args:
            uri (str): The URI of the vector store instance.
            user (Optional[str]): Username for authentication.
            password (Optional[str]): Password for authentication.
            token (Optional[str]): Access token for authentication.
            **kwargs: Additional parameters.
        """
        try:
            connections.connect(
                alias="default",
                uri=uri,
                user=user,
                password=password,
                token=token,
                **kwargs
            )
            logger.info(f"Connected to Milvus at {uri}.")
        except MilvusException as e:
            logger.error(f"Failed to connect to Milvus: {e}")
            raise ConnectionError(f"Failed to connect to Milvus: {e}")

    def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
        """
        Create a new vector collection in Milvus.

        Args:
            collection_name (str): The name of the collection.
            vector_size (int): The dimensionality of the vectors.
            distance_metric (str): The distance metric to use (e.g., 'L2', 'IP', 'COSINE').
        """
        if utility.has_collection(collection_name):
            logger.info(f"Collection {collection_name} already exists. Skipping creation.")
            self.collection = Collection(collection_name)
            return

        # Define the schema
        fields = [
            FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=64),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
            FieldSchema(name="payload", dtype=DataType.JSON)
        ]
        schema = CollectionSchema(fields=fields, description="Milvus Vector Store Collection")

        # Create the collection
        self.collection = Collection(name=collection_name, schema=schema)
        logger.info(f"Collection {collection_name} created successfully.")

        # Create index
        index_params = {
            "metric_type": distance_metric,
            "index_type": "AUTOINDEX",
            "params": {}
        }
        self.collection.create_index(field_name="vector", index_params=index_params)
        logger.info(f"Index created on collection {collection_name}.")

    def insert_vectors(
        self,
        collection_name: str,
        vectors: List[List[float]],
        payloads: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into a collection.

        Args:
            collection_name (str): The name of the collection.
            vectors (List[List[float]]): A list of vectors to insert.
            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
            ids (Optional[List[str]]): Optional unique identifiers for each vector.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        if ids is None:
            raise ValueError("Milvus requires IDs to be provided for each vector.")
        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

        data = [
            ids,
            vectors,
            payloads
        ]
        self.collection.insert(data)
        logger.info(f"Inserted {len(vectors)} vectors into collection {collection_name}.")

    def search_vectors(
        self,
        collection_name: str,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for similar vectors in a collection.

        Args:
            collection_name (str): The name of the collection.
            query_vector (List[float]): The vector to search with.
            top_k (int): The number of top results to return.
            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

        Returns:
            List[Dict[str, Any]]: A list of search results.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        search_params = {
            "metric_type": self.metric_type,
            "params": {}
        }

        expr = self._build_filter_expression(filters)
        results = self.collection.search(
            data=[query_vector],
            anns_field="vector",
            param=search_params,
            limit=top_k,
            expr=expr,
            output_fields=["id", "payload"]
        )

        output = []
        for hits in results:
            for hit in hits:
                result = {
                    'id': hit.entity.get('id'),
                    'score': hit.distance,
                    'payload': hit.entity.get('payload')
                }
                output.append(result)
        return output

    def _build_filter_expression(self, filters: Optional[Dict[str, Any]]) -> str:
        """
        Build an expression string for filtering in Milvus.

        Args:
            filters (Optional[Dict[str, Any]]): Filters to apply.

        Returns:
            str: The expression string.
        """
        if not filters:
            return ""

        expressions = []
        for key, value in filters.items():
            if isinstance(value, str):
                expressions.append(f'payload["{key}"] == "{value}"')
            else:
                expressions.append(f'payload["{key}"] == {value}')
        expr = " and ".join(expressions)
        return expr

    def delete_vector(self, collection_name: str, vector_id: str) -> None:
        """
        Delete a vector from a collection by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to delete.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        expr = f'id == "{vector_id}"'
        self.collection.delete(expr)
        logger.info(f"Deleted vector with ID {vector_id} from collection {collection_name}.")

    def update_vector(
        self,
        collection_name: str,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Update a vector's data or payload.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to update.
            vector (Optional[List[float]]): The new vector data.
            payload (Optional[Dict[str, Any]]): The new payload data.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        # Milvus doesn't support direct updates; need to delete and re-insert
        # Fetch existing vector and payload
        expr = f'id == "{vector_id}"'
        results = self.collection.query(expr=expr, output_fields=["vector", "payload"])

        if not results:
            raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

        existing_vector = results[0]['vector']
        existing_payload = results[0]['payload']

        new_vector = vector if vector is not None else existing_vector
        new_payload = payload if payload is not None else existing_payload

        # Delete the existing vector
        self.collection.delete(expr)

        # Re-insert with updated data
        self.insert_vectors(
            collection_name=collection_name,
            vectors=[new_vector],
            payloads=[new_payload],
            ids=[vector_id]
        )
        logger.info(f"Updated vector with ID {vector_id} in collection {collection_name}.")

    def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
        """
        Retrieve a vector by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector.

        Returns:
            Dict[str, Any]: A dictionary containing the vector data and payload.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        expr = f'id == "{vector_id}"'
        results = self.collection.query(expr=expr, output_fields=["vector", "payload"])

        if not results:
            raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

        vector_data = {
            'id': vector_id,
            'vector': results[0]['vector'],
            'payload': results[0]['payload']
        }
        return vector_data

    def list_collections(self) -> List[str]:
        """
        List all available vector collections.

        Returns:
            List[str]: A list of collection names.
        """
        return utility.list_collections()

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire vector collection.

        Args:
            collection_name (str): The name of the collection to delete.
        """
        self.collection.drop()
        logger.info(f"Deleted collection {collection_name}.")

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """
        Get information about a collection.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            Dict[str, Any]: Information about the collection.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        info = self.collection.describe()
        return info

    def __del__(self):
        """Clean up resources."""
        connections.disconnect("default")
__del__()

Clean up resources.

Source code in src/aeiva/storage/milvus/milvus_database.py
def __del__(self):
    """Clean up resources."""
    connections.disconnect("default")
__init__(config)

Initialize the Milvus vector store.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | Dict[str, Any] | Configuration dictionary. | required |

Source code in src/aeiva/storage/milvus/milvus_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the Milvus vector store.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.collection_name = config.get('collection_name')
    self.uri = config.get('uri')
    self.user = config.get('user')
    self.password = config.get('password')
    self.token = config.get('token')
    self.embedding_model_dims = config.get('embedding_model_dims')
    self.metric_type = config.get('metric_type', 'L2')  # Default to 'L2' metric

    if not all([self.collection_name, self.uri, self.embedding_model_dims]):
        raise ValueError("Required configuration parameters are missing.")

    self.create_client(
        uri=self.uri,
        user=self.user,
        password=self.password,
        token=self.token
    )
    self.create_collection(
        collection_name=self.collection_name,
        vector_size=self.embedding_model_dims,
        distance_metric=self.metric_type
    )
create_client(uri, user=None, password=None, token=None, **kwargs)

Initializes the client connection to the Milvus vector store.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| uri | str | The URI of the vector store instance. | required |
| user | Optional[str] | Username for authentication. | None |
| password | Optional[str] | Password for authentication. | None |
| token | Optional[str] | Access token for authentication. | None |
| **kwargs | | Additional parameters. | {} |

Source code in src/aeiva/storage/milvus/milvus_database.py
def create_client(
    self,
    uri: str,
    user: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    **kwargs
) -> None:
    """
    Initializes the client connection to the Milvus vector store.

    Args:
        uri (str): The URI of the vector store instance.
        user (Optional[str]): Username for authentication.
        password (Optional[str]): Password for authentication.
        token (Optional[str]): Access token for authentication.
        **kwargs: Additional parameters.
    """
    try:
        connections.connect(
            alias="default",
            uri=uri,
            user=user,
            password=password,
            token=token,
            **kwargs
        )
        logger.info(f"Connected to Milvus at {uri}.")
    except MilvusException as e:
        logger.error(f"Failed to connect to Milvus: {e}")
        raise ConnectionError(f"Failed to connect to Milvus: {e}")
create_collection(collection_name, vector_size, distance_metric)

Create a new vector collection in Milvus.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| collection_name | str | The name of the collection. | required |
| vector_size | int | The dimensionality of the vectors. | required |
| distance_metric | str | The distance metric to use (e.g., 'L2', 'IP', 'COSINE'). | required |

Source code in src/aeiva/storage/milvus/milvus_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
    """
    Create a new vector collection in Milvus.

    Args:
        collection_name (str): The name of the collection.
        vector_size (int): The dimensionality of the vectors.
        distance_metric (str): The distance metric to use (e.g., 'L2', 'IP', 'COSINE').
    """
    if utility.has_collection(collection_name):
        logger.info(f"Collection {collection_name} already exists. Skipping creation.")
        self.collection = Collection(collection_name)
        return

    # Define the schema
    fields = [
        FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=64),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
        FieldSchema(name="payload", dtype=DataType.JSON)
    ]
    schema = CollectionSchema(fields=fields, description="Milvus Vector Store Collection")

    # Create the collection
    self.collection = Collection(name=collection_name, schema=schema)
    logger.info(f"Collection {collection_name} created successfully.")

    # Create index
    index_params = {
        "metric_type": distance_metric,
        "index_type": "AUTOINDEX",
        "params": {}
    }
    self.collection.create_index(field_name="vector", index_params=index_params)
    logger.info(f"Index created on collection {collection_name}.")
delete_collection(collection_name)

Delete an entire vector collection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| collection_name | str | The name of the collection to delete. | required |

Source code in src/aeiva/storage/milvus/milvus_database.py
def delete_collection(self, collection_name: str) -> None:
    """
    Delete an entire vector collection.

    Args:
        collection_name (str): The name of the collection to delete.
    """
    self.collection.drop()
    logger.info(f"Deleted collection {collection_name}.")
delete_vector(collection_name, vector_id)

Delete a vector from a collection by its ID.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| collection_name | str | The name of the collection. | required |
| vector_id | str | The unique identifier of the vector to delete. | required |

Source code in src/aeiva/storage/milvus/milvus_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:
    """
    Delete a vector from a collection by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to delete.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    expr = f'id == "{vector_id}"'
    self.collection.delete(expr)
    logger.info(f"Deleted vector with ID {vector_id} from collection {collection_name}.")
get_collection_info(collection_name)

Get information about a collection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| collection_name | str | The name of the collection. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Any] | Information about the collection. |

Source code in src/aeiva/storage/milvus/milvus_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """
    Get information about a collection.

    Args:
        collection_name (str): The name of the collection.

    Returns:
        Dict[str, Any]: Information about the collection.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    info = self.collection.describe()
    return info
get_vector(collection_name, vector_id)

Retrieve a vector by its ID.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| collection_name | str | The name of the collection. | required |
| vector_id | str | The unique identifier of the vector. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Any] | A dictionary containing the vector data and payload. |

Source code in src/aeiva/storage/milvus/milvus_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
    """
    Retrieve a vector by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector.

    Returns:
        Dict[str, Any]: A dictionary containing the vector data and payload.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    expr = f'id == "{vector_id}"'
    results = self.collection.query(expr=expr, output_fields=["vector", "payload"])

    if not results:
        raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

    vector_data = {
        'id': vector_id,
        'vector': results[0]['vector'],
        'payload': results[0]['payload']
    }
    return vector_data
insert_vectors(collection_name, vectors, payloads=None, ids=None)

Insert vectors into a collection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| collection_name | str | The name of the collection. | required |
| vectors | List[List[float]] | A list of vectors to insert. | required |
| payloads | Optional[List[Dict[str, Any]]] | Optional metadata associated with each vector. | None |
| ids | Optional[List[str]] | Optional unique identifiers for each vector. | None |

Source code in src/aeiva/storage/milvus/milvus_database.py
def insert_vectors(
    self,
    collection_name: str,
    vectors: List[List[float]],
    payloads: Optional[List[Dict[str, Any]]] = None,
    ids: Optional[List[str]] = None
) -> None:
    """
    Insert vectors into a collection.

    Args:
        collection_name (str): The name of the collection.
        vectors (List[List[float]]): A list of vectors to insert.
        payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
        ids (Optional[List[str]]): Optional unique identifiers for each vector.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    if ids is None:
        raise ValueError("Milvus requires IDs to be provided for each vector.")
    if payloads is None:
        payloads = [{} for _ in range(len(vectors))]
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

    data = [
        ids,
        vectors,
        payloads
    ]
    self.collection.insert(data)
    logger.info(f"Inserted {len(vectors)} vectors into collection {collection_name}.")
list_collections()

List all available vector collections.

Returns:

| Type | Description |
| --- | --- |
| List[str] | A list of collection names. |

Source code in src/aeiva/storage/milvus/milvus_database.py
def list_collections(self) -> List[str]:
    """
    List all available vector collections.

    Returns:
        List[str]: A list of collection names.
    """
    return utility.list_collections()
search_vectors(collection_name, query_vector, top_k=5, filters=None)

Search for similar vectors in a collection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| collection_name | str | The name of the collection. | required |
| query_vector | List[float] | The vector to search with. | required |
| top_k | int | The number of top results to return. | 5 |
| filters | Optional[Dict[str, Any]] | Optional filters to apply to the search. | None |

Returns:

| Type | Description |
| --- | --- |
| List[Dict[str, Any]] | A list of search results. |

Source code in src/aeiva/storage/milvus/milvus_database.py
def search_vectors(
    self,
    collection_name: str,
    query_vector: List[float],
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    Search for similar vectors in a collection.

    Args:
        collection_name (str): The name of the collection.
        query_vector (List[float]): The vector to search with.
        top_k (int): The number of top results to return.
        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

    Returns:
        List[Dict[str, Any]]: A list of search results.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    search_params = {
        "metric_type": self.metric_type,
        "params": {}
    }

    expr = self._build_filter_expression(filters)
    results = self.collection.search(
        data=[query_vector],
        anns_field="vector",
        param=search_params,
        limit=top_k,
        expr=expr,
        output_fields=["id", "payload"]
    )

    output = []
    for hits in results:
        for hit in hits:
            result = {
                'id': hit.entity.get('id'),
                'score': hit.distance,
                'payload': hit.entity.get('payload')
            }
            output.append(result)
    return output
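A companion sketch for search_vectors, reusing the hypothetical store from the insert example above; each result is a dict with 'id', 'score', and 'payload' keys, as built in the loop above:

results = store.search_vectors(
    collection_name="docs",
    query_vector=[0.1, 0.2, 0.3, 0.4],
    top_k=3,
)
for hit in results:
    print(hit["id"], hit["score"], hit["payload"])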
update_vector(collection_name, vector_id, vector=None, payload=None)

Update a vector's data or payload.

Parameters:

    collection_name (str): The name of the collection. (required)
    vector_id (str): The unique identifier of the vector to update. (required)
    vector (Optional[List[float]]): The new vector data. (default: None)
    payload (Optional[Dict[str, Any]]): The new payload data. (default: None)
Source code in src/aeiva/storage/milvus/milvus_database.py
def update_vector(
    self,
    collection_name: str,
    vector_id: str,
    vector: Optional[List[float]] = None,
    payload: Optional[Dict[str, Any]] = None
) -> None:
    """
    Update a vector's data or payload.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to update.
        vector (Optional[List[float]]): The new vector data.
        payload (Optional[Dict[str, Any]]): The new payload data.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    # Milvus doesn't support direct updates; need to delete and re-insert
    # Fetch existing vector and payload
    expr = f'id == "{vector_id}"'
    results = self.collection.query(expr=expr, output_fields=["vector", "payload"])

    if not results:
        raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

    existing_vector = results[0]['vector']
    existing_payload = results[0]['payload']

    new_vector = vector if vector is not None else existing_vector
    new_payload = payload if payload is not None else existing_payload

    # Delete the existing vector
    self.collection.delete(expr)

    # Re-insert with updated data
    self.insert_vectors(
        collection_name=collection_name,
        vectors=[new_vector],
        payloads=[new_payload],
        ids=[vector_id]
    )
    logger.info(f"Updated vector with ID {vector_id} in collection {collection_name}.")

neo4jdb

neo4j_config

Neo4jConfig dataclass

Bases: BaseConfig

Configuration for Neo4j graph database.

Source code in src/aeiva/storage/neo4jdb/neo4j_config.py
@dataclass
class Neo4jConfig(BaseConfig):
    """
    Configuration for Neo4j graph database.
    """

    uri: str = field(
        default="bolt://localhost:7687",
        metadata={"help": "URI for connecting to Neo4j (e.g., 'bolt://localhost:7687')."}
    )
    user: Optional[str] = field(
        default=None,
        metadata={"help": "Username for Neo4j authentication."}
    )
    password: Optional[str] = field(
        default=None,
        metadata={"help": "Password for Neo4j authentication."}
    )
    database: Optional[str] = field(
        default="neo4j",
        metadata={"help": "Neo4j database name."}
    )
    encrypted: bool = field(
        default=True,
        metadata={"help": "Whether to use encrypted connection (True or False)."}
    )

    def __post_init__(self):
        super().__post_init__()
        if not self.user or not self.password:
            raise ValueError("Both 'user' and 'password' must be provided for Neo4j authentication.")

neo4j_database

Neo4jDatabase

Bases: GraphDatabase

Concrete implementation of GraphStoreBase using Neo4j.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class Neo4jDatabase(GraphDatabase):
    """
    Concrete implementation of GraphStoreBase using Neo4j.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the Neo4j graph database connection.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.uri = config.get('uri')
        self.user = config.get('user')
        self.password = config.get('password')
        self.database = config.get('database', 'neo4j')
        self.encrypted = config.get('encrypted', True)

        if not all([self.uri, self.user, self.password]):
            raise ValueError("Required configuration parameters 'uri', 'user', and 'password' are missing.")

        self.create_client(
            uri=self.uri,
            user=self.user,
            password=self.password,
            encrypted=self.encrypted
        )

    def create_client(
        self,
        uri: str,
        user: str,
        password: str,
        encrypted: bool = True,
        **kwargs
    ) -> None:
        """
        Initializes the client connection to the Neo4j graph database.

        Args:
            uri (str): The URI of the Neo4j instance.
            user (str): Username for authentication.
            password (str): Password for authentication.
            encrypted (bool): Whether to use encrypted connection.
            **kwargs: Additional parameters.

        Raises:
            ConnectionError: If the client fails to connect to the graph database.
        """
        try:
            auth = basic_auth(user, password)
            self.driver = Neo4jGraphDatabase.driver(uri, auth=auth, encrypted=encrypted, **kwargs)
            self.session = self.driver.session(database=self.database)
            logger.info(f"Connected to Neo4j at {uri}.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to connect to Neo4j: {e}")
            raise ConnectionError(f"Failed to connect to Neo4j: {e}")

    def add_node(
        self,
        node_id: str,
        properties: Optional[Dict[str, Any]] = None,
        labels: Optional[List[str]] = None
    ) -> None:
        """
        Adds a node to the graph.

        Args:
            node_id (str): Unique identifier for the node.
            properties (Optional[Dict[str, Any]]): Properties associated with the node.
            labels (Optional[List[str]]): Labels or types associated with the node.

        Raises:
            StorageError: If there is an issue adding the node.
        """
        properties = properties or {}
        labels = labels or []
        labels_str = ':' + ':'.join(labels) if labels else ''
        cypher = f"MERGE (n{labels_str} {{id: $node_id}}) SET n += $properties"
        params = {
            'node_id': node_id,
            'properties': properties
        }
        try:
            self.session.run(cypher, params)
            logger.info(f"Node with id '{node_id}' added to the graph.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to add node: {e}")
            raise StorageError(f"Failed to add node: {e}")

    def add_edge(
        self,
        source_id: str,
        target_id: str,
        relationship: str,
        properties: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Adds an edge (relationship) between two nodes.

        Args:
            source_id (str): Unique identifier of the source node.
            target_id (str): Unique identifier of the target node.
            relationship (str): Type of the relationship.
            properties (Optional[Dict[str, Any]]): Properties associated with the edge.

        Raises:
            NodeNotFoundError: If either the source or target node does not exist.
            StorageError: If there is an issue adding the edge.
        """
        properties = properties or {}
        # First, check if both nodes exist
        cypher_check = "MATCH (a {id: $source_id}), (b {id: $target_id}) RETURN a, b"
        params = {
            'source_id': source_id,
            'target_id': target_id
        }
        try:
            result = self.session.run(cypher_check, params)
            record = result.single()
            if not record:
                missing_nodes = []
                # Check if source node exists
                node_a_exists = self.session.run("MATCH (a {id: $source_id}) RETURN a", {'source_id': source_id}).single()
                if not node_a_exists:
                    missing_nodes.append(source_id)
                # Check if target node exists
                node_b_exists = self.session.run("MATCH (b {id: $target_id}) RETURN b", {'target_id': target_id}).single()
                if not node_b_exists:
                    missing_nodes.append(target_id)
                logger.warning(f"Node(s) with id(s) {missing_nodes} not found.")
                raise NodeNotFoundError(f"Node(s) with id(s) {missing_nodes} not found.")
            # Proceed to add the edge
            cypher_edge = (
                "MATCH (a {id: $source_id}), (b {id: $target_id}) "
                f"MERGE (a)-[r:{relationship}]->(b) "
                "SET r += $properties"
            )
            params['properties'] = properties
            self.session.run(cypher_edge, params)
            logger.info(f"Relationship '{relationship}' added between '{source_id}' and '{target_id}'.")
        except NodeNotFoundError:
            raise
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to add edge: {e}")
            raise StorageError(f"Failed to add edge: {e}")

    def get_node(self, node_id: str) -> Dict[str, Any]:
        """
        Retrieves a node by its identifier.

        Args:
            node_id (str): Unique identifier of the node.

        Returns:
            Dict[str, Any]: A dictionary containing the node's properties and labels.

        Raises:
            NodeNotFoundError: If the node does not exist.
            StorageError: If there is an issue retrieving the node.
        """
        cypher = "MATCH (n {id: $node_id}) RETURN n"
        params = {'node_id': node_id}
        try:
            result = self.session.run(cypher, params)
            record = result.single()
            if record:
                node = record['n']
                node_data = {
                    'id': node['id'],
                    'properties': {k: v for k, v in node.items() if k != 'id'},
                    'labels': list(node.labels)
                }
                logger.info(f"Node with id '{node_id}' retrieved.")
                return node_data
            else:
                logger.warning(f"Node with id '{node_id}' not found.")
                raise NodeNotFoundError(f"Node with id '{node_id}' not found.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to get node: {e}")
            raise StorageError(f"Failed to get node: {e}")

    def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:
        """
        Updates properties of a node.

        Args:
            node_id (str): Unique identifier of the node.
            properties (Dict[str, Any]): Properties to update.

        Raises:
            NodeNotFoundError: If the node does not exist.
            StorageError: If there is an issue updating the node.
        """
        cypher = "MATCH (n {id: $node_id}) SET n += $properties RETURN n"
        params = {
            'node_id': node_id,
            'properties': properties
        }
        try:
            result = self.session.run(cypher, params)
            record = result.single()
            if record:
                logger.info(f"Node with id '{node_id}' updated.")
            else:
                logger.warning(f"Node with id '{node_id}' not found.")
                raise NodeNotFoundError(f"Node with id '{node_id}' not found.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to update node: {e}")
            raise StorageError(f"Failed to update node: {e}")

    def delete_node(self, node_id: str) -> None:
        """
        Deletes a node from the graph.

        Args:
            node_id (str): Unique identifier of the node.

        Raises:
            NodeNotFoundError: If the node does not exist.
            StorageError: If there is an issue deleting the node.
        """
        cypher = "MATCH (n {id: $node_id}) DETACH DELETE n RETURN COUNT(n) AS count"
        params = {'node_id': node_id}
        try:
            result = self.session.run(cypher, params)
            record = result.single()
            if record and record['count'] > 0:
                logger.info(f"Node with id '{node_id}' deleted.")
            else:
                logger.warning(f"Node with id '{node_id}' not found.")
                raise NodeNotFoundError(f"Node with id '{node_id}' not found.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to delete node: {e}")
            raise StorageError(f"Failed to delete node: {e}")

    def delete_edge(
        self,
        source_id: str,
        target_id: str,
        relationship: str
    ) -> None:
        """
        Deletes a specific relationship between two nodes.

        Args:
            source_id (str): Unique identifier of the source node.
            target_id (str): Unique identifier of the target node.
            relationship (str): Type of the relationship to delete.

        Raises:
            StorageError: If there is an issue deleting the relationship.
        """
        cypher = (
            "MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) "
            "DELETE r"
        ) % relationship
        params = {
            'source_id': source_id,
            'target_id': target_id
        }
        try:
            result = self.session.run(cypher, params)
            if result.consume().counters.relationships_deleted == 0:
                logger.warning(f"No relationship '{relationship}' found between '{source_id}' and '{target_id}'.")
                raise StorageError(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
            logger.info(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' deleted.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to delete relationship: {e}")
            raise StorageError(f"Failed to delete relationship: {e}")

    def update_edge(
        self,
        source_id: str,
        target_id: str,
        relationship: str,
        properties: Dict[str, Any]
    ) -> None:
        """
        Updates properties of a specific relationship between two nodes.

        Args:
            source_id (str): Unique identifier of the source node.
            target_id (str): Unique identifier of the target node.
            relationship (str): Type of the relationship to update.
            properties (Dict[str, Any]): Properties to update on the relationship.

        Raises:
            StorageError: If there is an issue updating the relationship.
        """
        cypher = (
            "MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) "
            "SET r += $properties RETURN r"
        ) % relationship
        params = {
            'source_id': source_id,
            'target_id': target_id,
            'properties': properties
        }
        try:
            result = self.session.run(cypher, params)
            record = result.single()
            if record:
                logger.info(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' updated with properties {properties}.")
            else:
                logger.warning(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
                raise StorageError(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to update relationship: {e}")
            raise StorageError(f"Failed to update relationship: {e}")

    def get_relationship(
        self,
        source_id: str,
        target_id: str,
        relationship: str
    ) -> Dict[str, Any]:
        """
        Retrieves a specific relationship between two nodes.

        Args:
            source_id (str): Unique identifier of the source node.
            target_id (str): Unique identifier of the target node.
            relationship (str): Type of the relationship to retrieve.

        Returns:
            Dict[str, Any]: A dictionary containing the relationship's properties.

        Raises:
            StorageError: If there is an issue retrieving the relationship.
        """
        cypher = (
            "MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) "
            "RETURN r"
        ) % relationship
        params = {
            'source_id': source_id,
            'target_id': target_id
        }
        try:
            result = self.session.run(cypher, params)
            record = result.single()
            if record:
                relationship_data = record['r']
                properties = dict(relationship_data)
                properties['type'] = relationship  # Include relationship type (a plain string)
                logger.info(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' retrieved.")
                return properties
            else:
                logger.warning(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
                raise StorageError(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to retrieve relationship: {e}")
            raise StorageError(f"Failed to retrieve relationship: {e}")

    def delete_all_edges(self) -> None:
        """
        Deletes all relationships from the Neo4j graph database without deleting nodes.

        Raises:
            StorageError: If there is an issue deleting relationships.
        """
        cypher = "MATCH ()-[r]->() DELETE r"
        try:
            self.session.run(cypher)
            logger.info("All relationships have been deleted from Neo4j.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to delete all relationships: {e}")
            raise StorageError(f"Failed to delete all relationships: {e}")

    def delete_relationships_by_type(self, relationship: str) -> None:
        """
        Deletes all relationships of a specific type from the Neo4j graph database.

        Args:
            relationship (str): The type of relationships to delete.

        Raises:
            StorageError: If there is an issue deleting the relationships.
        """
        cypher = f"MATCH ()-[r:{relationship}]->() DELETE r"
        try:
            self.session.run(cypher)
            logger.info(f"All relationships of type '{relationship}' have been deleted from Neo4j.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to delete relationships of type '{relationship}': {e}")
            raise StorageError(f"Failed to delete relationships of type '{relationship}': {e}")

    def delete_all(self) -> None:
        """
        Deletes all nodes and relationships from the Neo4j graph database.

        Raises:
            StorageError: If there is an issue deleting all nodes and relationships.
        """
        cypher = "MATCH (n) DETACH DELETE n"
        try:
            self.session.run(cypher)
            logger.info("All nodes and relationships have been deleted from Neo4j.")
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to delete all nodes and relationships: {e}")
            raise StorageError(f"Failed to delete all nodes and relationships: {e}")

    def get_neighbors(
        self,
        node_id: str,
        relationship: Optional[str] = None,
        direction: str = "both"
    ) -> List[Dict[str, Any]]:
        """
        Retrieves neighboring nodes connected by edges.

        Args:
            node_id (str): Unique identifier of the node.
            relationship (Optional[str]): Filter by relationship type.
            direction (str): Direction of the relationships ('in', 'out', 'both').

        Returns:
            List[Dict[str, Any]]: A list of neighboring nodes.

        Raises:
            NodeNotFoundError: If the node does not exist.
            StorageError: If there is an issue retrieving neighbors.
        """
        if direction not in ["in", "out", "both"]:
            raise ValueError("Invalid direction. Must be 'in', 'out', or 'both'.")

        rel_type = f":{relationship}" if relationship else ''
        if direction == "in":
            pattern = f"<-[r{rel_type}]-"
        elif direction == "out":
            pattern = f"-[r{rel_type}]->"
        else:  # both
            pattern = f"-[r{rel_type}]-"

        cypher = f"MATCH (n {{id: $node_id}}){pattern}(neighbor) RETURN neighbor"
        params = {'node_id': node_id}
        try:
            # First, check if the node exists
            node_exists_query = "MATCH (n {id: $node_id}) RETURN n"
            node_result = self.session.run(node_exists_query, params)
            if not node_result.single():
                logger.warning(f"Node with id '{node_id}' not found.")
                raise NodeNotFoundError(f"Node with id '{node_id}' not found.")
            # Get neighbors
            result = self.session.run(cypher, params)
            neighbors = []
            for record in result:
                node = record['neighbor']
                neighbor_data = {
                    'id': node['id'],
                    'properties': {k: v for k, v in node.items() if k != 'id'},
                    'labels': list(node.labels)
                }
                neighbors.append(neighbor_data)
            logger.info(f"Neighbors of node '{node_id}' retrieved.")
            return neighbors
        except NodeNotFoundError:
            raise
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to get neighbors: {e}")
            raise StorageError(f"Failed to get neighbors: {e}")

    def query_nodes(
        self,
        properties: Dict[str, Any],
        labels: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """
        Queries nodes based on properties and labels.

        Args:
            properties (Dict[str, Any]): Properties to filter nodes.
            labels (Optional[List[str]]): Labels to filter nodes.

        Returns:
            List[Dict[str, Any]]: A list of nodes matching the query.

        Raises:
            StorageError: If there is an issue querying nodes.
        """
        labels_str = ':' + ':'.join(labels) if labels else ''
        params = {}
        cypher = f"MATCH (n{labels_str})"

        if properties:
            props_conditions = ' AND '.join([f"n.{key} = ${key}" for key in properties.keys()])
            cypher += f" WHERE {props_conditions}"
            params.update(properties)

        cypher += " RETURN n"

        try:
            result = self.session.run(cypher, params)
            nodes = []
            for record in result:
                node = record['n']
                node_data = {
                    'id': node['id'],
                    'properties': {k: v for k, v in node.items() if k != 'id'},
                    'labels': list(node.labels)
                }
                nodes.append(node_data)
            logger.info(f"Query returned {len(nodes)} nodes.")
            return nodes
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to query nodes: {e}")
            raise StorageError(f"Failed to query nodes: {e}")

    def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:
        """
        Executes a raw query against the graph database.

        Args:
            query (str): The query string.
            parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.

        Returns:
            Any: The result of the query.

        Raises:
            StorageError: If there is an issue executing the query.
        """
        try:
            result = self.session.run(query, parameters)
            records = [record.data() for record in result]
            logger.info(f"Executed query: {query}")
            return records
        except exceptions.Neo4jError as e:
            logger.error(f"Failed to execute query: {e}")
            raise StorageError(f"Failed to execute query: {e}")

    def close(self) -> None:
        """
        Closes the graph database connection and releases resources.
        """
        if hasattr(self, 'session') and self.session:
            self.session.close()
        if hasattr(self, 'driver') and self.driver:
            self.driver.close()
        logger.info("Closed connection to Neo4j database.")

    def __del__(self):
        """Destructor to ensure resources are cleaned up."""
        self.close()
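A minimal end-to-end sketch of the class above, assuming a reachable local Neo4j instance (credentials are placeholders). Two caveats are visible in the code: relationship types and property keys are interpolated directly into the Cypher text, so they must be trusted identifiers rather than user input, and a single long-lived session is shared across all calls.

db = Neo4jDatabase({
    "uri": "bolt://localhost:7687",
    "user": "neo4j",
    "password": "secret",   # placeholder credentials
})
try:
    db.add_node("alice", {"name": "Alice"}, labels=["Person"])
    db.add_node("bob", {"name": "Bob"}, labels=["Person"])
    db.add_edge("alice", "bob", "KNOWS", {"since": 2020})
    print(db.get_neighbors("alice", relationship="KNOWS", direction="out"))
finally:
    db.close()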
__del__()

Destructor to ensure resources are cleaned up.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def __del__(self):
    """Destructor to ensure resources are cleaned up."""
    self.close()
__init__(config)

Initialize the Neo4j graph database connection.

Parameters:

    config (Dict[str, Any]): Configuration dictionary. (required)
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the Neo4j graph database connection.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.uri = config.get('uri')
    self.user = config.get('user')
    self.password = config.get('password')
    self.database = config.get('database', 'neo4j')
    self.encrypted = config.get('encrypted', True)

    if not all([self.uri, self.user, self.password]):
        raise ValueError("Required configuration parameters 'uri', 'user', and 'password' are missing.")

    self.create_client(
        uri=self.uri,
        user=self.user,
        password=self.password,
        encrypted=self.encrypted
    )
add_edge(source_id, target_id, relationship, properties=None)

Adds an edge (relationship) between two nodes.

Parameters:

    source_id (str): Unique identifier of the source node. (required)
    target_id (str): Unique identifier of the target node. (required)
    relationship (str): Type of the relationship. (required)
    properties (Optional[Dict[str, Any]]): Properties associated with the edge. (default: None)

Raises:

    NodeNotFoundError: If either the source or target node does not exist.
    StorageError: If there is an issue adding the edge.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def add_edge(
    self,
    source_id: str,
    target_id: str,
    relationship: str,
    properties: Optional[Dict[str, Any]] = None
) -> None:
    """
    Adds an edge (relationship) between two nodes.

    Args:
        source_id (str): Unique identifier of the source node.
        target_id (str): Unique identifier of the target node.
        relationship (str): Type of the relationship.
        properties (Optional[Dict[str, Any]]): Properties associated with the edge.

    Raises:
        NodeNotFoundError: If either the source or target node does not exist.
        StorageError: If there is an issue adding the edge.
    """
    properties = properties or {}
    # First, check if both nodes exist
    cypher_check = "MATCH (a {id: $source_id}), (b {id: $target_id}) RETURN a, b"
    params = {
        'source_id': source_id,
        'target_id': target_id
    }
    try:
        result = self.session.run(cypher_check, params)
        record = result.single()
        if not record:
            missing_nodes = []
            # Check if source node exists
            node_a_exists = self.session.run("MATCH (a {id: $source_id}) RETURN a", {'source_id': source_id}).single()
            if not node_a_exists:
                missing_nodes.append(source_id)
            # Check if target node exists
            node_b_exists = self.session.run("MATCH (b {id: $target_id}) RETURN b", {'target_id': target_id}).single()
            if not node_b_exists:
                missing_nodes.append(target_id)
            logger.warning(f"Node(s) with id(s) {missing_nodes} not found.")
            raise NodeNotFoundError(f"Node(s) with id(s) {missing_nodes} not found.")
        # Proceed to add the edge
        cypher_edge = (
            "MATCH (a {id: $source_id}), (b {id: $target_id}) "
            f"MERGE (a)-[r:{relationship}]->(b) "
            "SET r += $properties"
        )
        params['properties'] = properties
        self.session.run(cypher_edge, params)
        logger.info(f"Relationship '{relationship}' added between '{source_id}' and '{target_id}'.")
    except NodeNotFoundError:
        raise
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to add edge: {e}")
        raise StorageError(f"Failed to add edge: {e}")
add_node(node_id, properties=None, labels=None)

Adds a node to the graph.

Parameters:

    node_id (str): Unique identifier for the node. (required)
    properties (Optional[Dict[str, Any]]): Properties associated with the node. (default: None)
    labels (Optional[List[str]]): Labels or types associated with the node. (default: None)

Raises:

    StorageError: If there is an issue adding the node.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def add_node(
    self,
    node_id: str,
    properties: Optional[Dict[str, Any]] = None,
    labels: Optional[List[str]] = None
) -> None:
    """
    Adds a node to the graph.

    Args:
        node_id (str): Unique identifier for the node.
        properties (Optional[Dict[str, Any]]): Properties associated with the node.
        labels (Optional[List[str]]): Labels or types associated with the node.

    Raises:
        StorageError: If there is an issue adding the node.
    """
    properties = properties or {}
    labels = labels or []
    labels_str = ':' + ':'.join(labels) if labels else ''
    cypher = f"MERGE (n{labels_str} {{id: $node_id}}) SET n += $properties"
    params = {
        'node_id': node_id,
        'properties': properties
    }
    try:
        self.session.run(cypher, params)
        logger.info(f"Node with id '{node_id}' added to the graph.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to add node: {e}")
        raise StorageError(f"Failed to add node: {e}")
close()

Closes the graph database connection and releases resources.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def close(self) -> None:
    """
    Closes the graph database connection and releases resources.
    """
    if hasattr(self, 'session') and self.session:
        self.session.close()
    if hasattr(self, 'driver') and self.driver:
        self.driver.close()
    logger.info("Closed connection to Neo4j database.")
create_client(uri, user, password, encrypted=True, **kwargs)

Initializes the client connection to the Neo4j graph database.

Parameters:

    uri (str): The URI of the Neo4j instance. (required)
    user (str): Username for authentication. (required)
    password (str): Password for authentication. (required)
    encrypted (bool): Whether to use encrypted connection. (default: True)
    **kwargs: Additional parameters. (default: {})

Raises:

    ConnectionError: If the client fails to connect to the graph database.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def create_client(
    self,
    uri: str,
    user: str,
    password: str,
    encrypted: bool = True,
    **kwargs
) -> None:
    """
    Initializes the client connection to the Neo4j graph database.

    Args:
        uri (str): The URI of the Neo4j instance.
        user (str): Username for authentication.
        password (str): Password for authentication.
        encrypted (bool): Whether to use encrypted connection.
        **kwargs: Additional parameters.

    Raises:
        ConnectionError: If the client fails to connect to the graph database.
    """
    try:
        auth = basic_auth(user, password)
        self.driver = Neo4jGraphDatabase.driver(uri, auth=auth, encrypted=encrypted, **kwargs)
        self.session = self.driver.session(database=self.database)
        logger.info(f"Connected to Neo4j at {uri}.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to connect to Neo4j: {e}")
        raise ConnectionError(f"Failed to connect to Neo4j: {e}")
delete_all()

Deletes all nodes and relationships from the Neo4j graph database.

Raises:

    StorageError: If there is an issue deleting all nodes and relationships.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_all(self) -> None:
    """
    Deletes all nodes and relationships from the Neo4j graph database.

    Raises:
        StorageError: If there is an issue deleting all nodes and relationships.
    """
    cypher = "MATCH (n) DETACH DELETE n"
    try:
        self.session.run(cypher)
        logger.info("All nodes and relationships have been deleted from Neo4j.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to delete all nodes and relationships: {e}")
        raise StorageError(f"Failed to delete all nodes and relationships: {e}")
delete_all_edges()

Deletes all relationships from the Neo4j graph database without deleting nodes.

Raises:

    StorageError: If there is an issue deleting relationships.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_all_edges(self) -> None:
    """
    Deletes all relationships from the Neo4j graph database without deleting nodes.

    Raises:
        StorageError: If there is an issue deleting relationships.
    """
    cypher = "MATCH ()-[r]->() DELETE r"
    try:
        self.session.run(cypher)
        logger.info("All relationships have been deleted from Neo4j.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to delete all relationships: {e}")
        raise StorageError(f"Failed to delete all relationships: {e}")
delete_edge(source_id, target_id, relationship)

Deletes a specific relationship between two nodes.

Parameters:

    source_id (str): Unique identifier of the source node. (required)
    target_id (str): Unique identifier of the target node. (required)
    relationship (str): Type of the relationship to delete. (required)

Raises:

    StorageError: If there is an issue deleting the relationship.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_edge(
    self,
    source_id: str,
    target_id: str,
    relationship: str
) -> None:
    """
    Deletes a specific relationship between two nodes.

    Args:
        source_id (str): Unique identifier of the source node.
        target_id (str): Unique identifier of the target node.
        relationship (str): Type of the relationship to delete.

    Raises:
        StorageError: If there is an issue deleting the relationship.
    """
    cypher = (
        "MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) "
        "DELETE r"
    ) % relationship
    params = {
        'source_id': source_id,
        'target_id': target_id
    }
    try:
        result = self.session.run(cypher, params)
        if result.consume().counters.relationships_deleted == 0:
            logger.warning(f"No relationship '{relationship}' found between '{source_id}' and '{target_id}'.")
            raise StorageError(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
        logger.info(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' deleted.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to delete relationship: {e}")
        raise StorageError(f"Failed to delete relationship: {e}")
delete_node(node_id)

Deletes a node from the graph.

Parameters:

    node_id (str): Unique identifier of the node. (required)

Raises:

    NodeNotFoundError: If the node does not exist.
    StorageError: If there is an issue deleting the node.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_node(self, node_id: str) -> None:
    """
    Deletes a node from the graph.

    Args:
        node_id (str): Unique identifier of the node.

    Raises:
        NodeNotFoundError: If the node does not exist.
        StorageError: If there is an issue deleting the node.
    """
    cypher = "MATCH (n {id: $node_id}) DETACH DELETE n RETURN COUNT(n) AS count"
    params = {'node_id': node_id}
    try:
        result = self.session.run(cypher, params)
        record = result.single()
        if record and record['count'] > 0:
            logger.info(f"Node with id '{node_id}' deleted.")
        else:
            logger.warning(f"Node with id '{node_id}' not found.")
            raise NodeNotFoundError(f"Node with id '{node_id}' not found.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to delete node: {e}")
        raise StorageError(f"Failed to delete node: {e}")
delete_relationships_by_type(relationship)

Deletes all relationships of a specific type from the Neo4j graph database.

Parameters:

    relationship (str): The type of relationships to delete. (required)

Raises:

    StorageError: If there is an issue deleting the relationships.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_relationships_by_type(self, relationship: str) -> None:
    """
    Deletes all relationships of a specific type from the Neo4j graph database.

    Args:
        relationship (str): The type of relationships to delete.

    Raises:
        StorageError: If there is an issue deleting the relationships.
    """
    cypher = f"MATCH ()-[r:{relationship}]->() DELETE r"
    try:
        self.session.run(cypher)
        logger.info(f"All relationships of type '{relationship}' have been deleted from Neo4j.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to delete relationships of type '{relationship}': {e}")
        raise StorageError(f"Failed to delete relationships of type '{relationship}': {e}")
execute_query(query, parameters=None)

Executes a raw query against the graph database.

Parameters:

    query (str): The query string. (required)
    parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries. (default: None)

Returns:

    Any: The result of the query.

Raises:

    StorageError: If there is an issue executing the query.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:
    """
    Executes a raw query against the graph database.

    Args:
        query (str): The query string.
        parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.

    Returns:
        Any: The result of the query.

    Raises:
        StorageError: If there is an issue executing the query.
    """
    try:
        result = self.session.run(query, parameters)
        records = [record.data() for record in result]
        logger.info(f"Executed query: {query}")
        return records
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to execute query: {e}")
        raise StorageError(f"Failed to execute query: {e}")
get_neighbors(node_id, relationship=None, direction='both')

Retrieves neighboring nodes connected by edges.

Parameters:

    node_id (str): Unique identifier of the node. (required)
    relationship (Optional[str]): Filter by relationship type. (default: None)
    direction (str): Direction of the relationships ('in', 'out', 'both'). (default: 'both')

Returns:

    List[Dict[str, Any]]: A list of neighboring nodes.

Raises:

    NodeNotFoundError: If the node does not exist.
    StorageError: If there is an issue retrieving neighbors.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_neighbors(
    self,
    node_id: str,
    relationship: Optional[str] = None,
    direction: str = "both"
) -> List[Dict[str, Any]]:
    """
    Retrieves neighboring nodes connected by edges.

    Args:
        node_id (str): Unique identifier of the node.
        relationship (Optional[str]): Filter by relationship type.
        direction (str): Direction of the relationships ('in', 'out', 'both').

    Returns:
        List[Dict[str, Any]]: A list of neighboring nodes.

    Raises:
        NodeNotFoundError: If the node does not exist.
        StorageError: If there is an issue retrieving neighbors.
    """
    if direction not in ["in", "out", "both"]:
        raise ValueError("Invalid direction. Must be 'in', 'out', or 'both'.")

    rel_type = f":{relationship}" if relationship else ''
    if direction == "in":
        pattern = f"<-[r{rel_type}]-"
    elif direction == "out":
        pattern = f"-[r{rel_type}]->"
    else:  # both
        pattern = f"-[r{rel_type}]-"

    cypher = f"MATCH (n {{id: $node_id}}){pattern}(neighbor) RETURN neighbor"
    params = {'node_id': node_id}
    try:
        # First, check if the node exists
        node_exists_query = "MATCH (n {id: $node_id}) RETURN n"
        node_result = self.session.run(node_exists_query, params)
        if not node_result.single():
            logger.warning(f"Node with id '{node_id}' not found.")
            raise NodeNotFoundError(f"Node with id '{node_id}' not found.")
        # Get neighbors
        result = self.session.run(cypher, params)
        neighbors = []
        for record in result:
            node = record['neighbor']
            neighbor_data = {
                'id': node['id'],
                'properties': {k: v for k, v in node.items() if k != 'id'},
                'labels': list(node.labels)
            }
            neighbors.append(neighbor_data)
        logger.info(f"Neighbors of node '{node_id}' retrieved.")
        return neighbors
    except NodeNotFoundError:
        raise
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to get neighbors: {e}")
        raise StorageError(f"Failed to get neighbors: {e}")
get_node(node_id)

Retrieves a node by its identifier.

Parameters:

    node_id (str): Unique identifier of the node. (required)

Returns:

    Dict[str, Any]: A dictionary containing the node's properties and labels.

Raises:

    NodeNotFoundError: If the node does not exist.
    StorageError: If there is an issue retrieving the node.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_node(self, node_id: str) -> Dict[str, Any]:
    """
    Retrieves a node by its identifier.

    Args:
        node_id (str): Unique identifier of the node.

    Returns:
        Dict[str, Any]: A dictionary containing the node's properties and labels.

    Raises:
        NodeNotFoundError: If the node does not exist.
        StorageError: If there is an issue retrieving the node.
    """
    cypher = "MATCH (n {id: $node_id}) RETURN n"
    params = {'node_id': node_id}
    try:
        result = self.session.run(cypher, params)
        record = result.single()
        if record:
            node = record['n']
            node_data = {
                'id': node['id'],
                'properties': {k: v for k, v in node.items() if k != 'id'},
                'labels': list(node.labels)
            }
            logger.info(f"Node with id '{node_id}' retrieved.")
            return node_data
        else:
            logger.warning(f"Node with id '{node_id}' not found.")
            raise NodeNotFoundError(f"Node with id '{node_id}' not found.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to get node: {e}")
        raise StorageError(f"Failed to get node: {e}")
get_relationship(source_id, target_id, relationship)

Retrieves a specific relationship between two nodes.

Parameters:

    source_id (str): Unique identifier of the source node. (required)
    target_id (str): Unique identifier of the target node. (required)
    relationship (str): Type of the relationship to retrieve. (required)

Returns:

    Dict[str, Any]: A dictionary containing the relationship's properties.

Raises:

    StorageError: If there is an issue retrieving the relationship.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_relationship(
    self,
    source_id: str,
    target_id: str,
    relationship: str
) -> Dict[str, Any]:
    """
    Retrieves a specific relationship between two nodes.

    Args:
        source_id (str): Unique identifier of the source node.
        target_id (str): Unique identifier of the target node.
        relationship (str): Type of the relationship to retrieve.

    Returns:
        Dict[str, Any]: A dictionary containing the relationship's properties.

    Raises:
        StorageError: If there is an issue retrieving the relationship.
    """
    cypher = (
        "MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) "
        "RETURN r"
    ) % relationship
    params = {
        'source_id': source_id,
        'target_id': target_id
    }
    try:
        result = self.session.run(cypher, params)
        record = result.single()
        if record:
            relationship_data = record['r']
            properties = dict(relationship_data)
            properties['type'] = relationship  # Include relationship type (a plain string)
            logger.info(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' retrieved.")
            return properties
        else:
            logger.warning(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
            raise StorageError(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to retrieve relationship: {e}")
        raise StorageError(f"Failed to retrieve relationship: {e}")
query_nodes(properties, labels=None)

Queries nodes based on properties and labels.

Parameters:

    properties (Dict[str, Any]): Properties to filter nodes. (required)
    labels (Optional[List[str]]): Labels to filter nodes. (default: None)

Returns:

    List[Dict[str, Any]]: A list of nodes matching the query.

Raises:

    StorageError: If there is an issue querying nodes.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def query_nodes(
    self,
    properties: Dict[str, Any],
    labels: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
    """
    Queries nodes based on properties and labels.

    Args:
        properties (Dict[str, Any]): Properties to filter nodes.
        labels (Optional[List[str]]): Labels to filter nodes.

    Returns:
        List[Dict[str, Any]]: A list of nodes matching the query.

    Raises:
        StorageError: If there is an issue querying nodes.
    """
    labels_str = ':' + ':'.join(labels) if labels else ''
    params = {}
    cypher = f"MATCH (n{labels_str})"

    if properties:
        props_conditions = ' AND '.join([f"n.{key} = ${key}" for key in properties.keys()])
        cypher += f" WHERE {props_conditions}"
        params.update(properties)

    cypher += " RETURN n"

    try:
        result = self.session.run(cypher, params)
        nodes = []
        for record in result:
            node = record['n']
            node_data = {
                'id': node['id'],
                'properties': {k: v for k, v in node.items() if k != 'id'},
                'labels': list(node.labels)
            }
            nodes.append(node_data)
        logger.info(f"Query returned {len(nodes)} nodes.")
        return nodes
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to query nodes: {e}")
        raise StorageError(f"Failed to query nodes: {e}")
update_edge(source_id, target_id, relationship, properties)

Updates properties of a specific relationship between two nodes.

Parameters:

    source_id (str): Unique identifier of the source node. (required)
    target_id (str): Unique identifier of the target node. (required)
    relationship (str): Type of the relationship to update. (required)
    properties (Dict[str, Any]): Properties to update on the relationship. (required)

Raises:

    StorageError: If there is an issue updating the relationship.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def update_edge(
    self,
    source_id: str,
    target_id: str,
    relationship: str,
    properties: Dict[str, Any]
) -> None:
    """
    Updates properties of a specific relationship between two nodes.

    Args:
        source_id (str): Unique identifier of the source node.
        target_id (str): Unique identifier of the target node.
        relationship (str): Type of the relationship to update.
        properties (Dict[str, Any]): Properties to update on the relationship.

    Raises:
        StorageError: If there is an issue updating the relationship.
    """
    cypher = (
        "MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) "
        "SET r += $properties RETURN r"
    ) % relationship
    params = {
        'source_id': source_id,
        'target_id': target_id,
        'properties': properties
    }
    try:
        result = self.session.run(cypher, params)
        record = result.single()
        if record:
            logger.info(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' updated with properties {properties}.")
        else:
            logger.warning(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
            raise StorageError(f"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to update relationship: {e}")
        raise StorageError(f"Failed to update relationship: {e}")
update_node(node_id, properties)

Updates properties of a node.

Parameters:

    node_id (str): Unique identifier of the node. (required)
    properties (Dict[str, Any]): Properties to update. (required)

Raises:

    NodeNotFoundError: If the node does not exist.
    StorageError: If there is an issue updating the node.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:
    """
    Updates properties of a node.

    Args:
        node_id (str): Unique identifier of the node.
        properties (Dict[str, Any]): Properties to update.

    Raises:
        NodeNotFoundError: If the node does not exist.
        StorageError: If there is an issue updating the node.
    """
    cypher = "MATCH (n {id: $node_id}) SET n += $properties RETURN n"
    params = {
        'node_id': node_id,
        'properties': properties
    }
    try:
        result = self.session.run(cypher, params)
        record = result.single()
        if record:
            logger.info(f"Node with id '{node_id}' updated.")
        else:
            logger.warning(f"Node with id '{node_id}' not found.")
            raise NodeNotFoundError(f"Node with id '{node_id}' not found.")
    except exceptions.Neo4jError as e:
        logger.error(f"Failed to update node: {e}")
        raise StorageError(f"Failed to update node: {e}")
NodeNotFoundError

Bases: Exception

Exception raised when a node is not found in the graph database.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class NodeNotFoundError(Exception):
    """Exception raised when a node is not found in the graph database."""
    pass
StorageError

Bases: Exception

Exception raised when there is a storage-related error in the graph database.

Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class StorageError(Exception):
    """Exception raised when there is a storage-related error in the graph database."""
    pass
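A sketch of telling the two exception types apart when reading a node (hypothetical db as above): NodeNotFoundError signals a missing id, while StorageError wraps driver-level failures.

try:
    node = db.get_node("charlie")
except NodeNotFoundError:
    node = None                 # the id simply does not exist
except StorageError as exc:
    raise RuntimeError(f"Neo4j lookup failed: {exc}")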

pgvector

pgvector_config

PGVectorConfig dataclass

Bases: BaseConfig

Configuration for PGVector (PostgreSQL with vector extension).

Source code in src/aeiva/storage/pgvector/pgvector_config.py
@dataclass
class PGVectorConfig(BaseConfig):
    """
    Configuration for PGVector (PostgreSQL with vector extension).
    """

    dbname: str = field(
        default="postgres",
        metadata={"help": "Name of the database."}
    )
    collection_name: str = field(
        default="mem0",
        metadata={"help": "Name of the collection (table name)."}
    )
    embedding_model_dims: int = field(
        default=1536,
        metadata={"help": "Dimensions of the embedding model."}
    )
    user: Optional[str] = field(
        default=None,
        metadata={"help": "Database user."}
    )
    password: Optional[str] = field(
        default=None,
        metadata={"help": "Database password."}
    )
    host: str = field(
        default="localhost",
        metadata={"help": "Database host."}
    )
    port: int = field(
        default=5432,
        metadata={"help": "Database port."}
    )
    use_diskann: bool = field(
        default=True,
        metadata={"help": "Whether to use diskann for approximate nearest neighbors search."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Validate that user and password are provided
        if not self.user or not self.password:
            raise ValueError("Both 'user' and 'password' must be provided.")
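
A minimal configuration sketch for the dataclass above. The values are illustrative placeholders; note that __post_init__ rejects configs without 'user' and 'password'.

from aeiva.storage.pgvector.pgvector_config import PGVectorConfig

config = PGVectorConfig(
    dbname="postgres",
    collection_name="documents",
    embedding_model_dims=1536,
    user="postgres",        # placeholder credentials
    password="secret",
    host="localhost",
    port=5432,
)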

pgvector_database

PGVectorDatabase

Bases: VectorDatabase

Concrete implementation of VectorStoreBase using PGVector.

Source code in src/aeiva/storage/pgvector/pgvector_database.py
class PGVectorDatabase(VectorDatabase):
    """
    Concrete implementation of VectorStoreBase using PGVector.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the PGVector vector store.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.collection_name = config.get('collection_name')
        self.dbname = config.get('dbname')
        self.user = config.get('user')
        self.password = config.get('password')
        self.host = config.get('host', 'localhost')
        self.port = config.get('port', 5432)
        self.embedding_model_dims = config.get('embedding_model_dims')
        self.use_diskann = config.get('use_diskann', False)

        if not all([self.collection_name, self.dbname, self.user, self.password, self.embedding_model_dims]):
            raise ValueError("Required configuration parameters are missing.")

        self.create_client()
        self.create_collection(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            distance_metric='cosine'  # PGVector uses cosine by default
        )

    def create_client(self, **kwargs) -> None:
        """
        Initializes the client connection to the PGVector database.

        Args:
            **kwargs: Additional parameters.
        """
        try:
            self.conn = psycopg2.connect(
                dbname=self.dbname,
                user=self.user,
                password=self.password,
                host=self.host,
                port=self.port,
                **kwargs
            )
            self.cur = self.conn.cursor()
            logger.info("Connected to PGVector database.")
        except psycopg2.Error as e:
            logger.error(f"Failed to connect to PGVector database: {e}")
            raise ConnectionError(f"Failed to connect to PGVector database: {e}")

    def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
        """
        Create a new vector collection (table) in PGVector.

        Args:
            collection_name (str): The name of the collection.
            vector_size (int): The dimensionality of the vectors.
            distance_metric (str): The distance metric to use (e.g., 'cosine').
        """
        # Check if table exists
        self.cur.execute(
            "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name=%s);",
            (collection_name,)
        )
        exists = self.cur.fetchone()[0]
        if exists:
            logger.info(f"Table {collection_name} already exists. Skipping creation.")
            return

        # Create table
        create_table_query = f"""
        CREATE TABLE {collection_name} (
            id VARCHAR(64) PRIMARY KEY,
            vector vector({vector_size}),
            payload JSONB
        );
        """
        self.cur.execute(create_table_query)
        self.conn.commit()
        logger.info(f"Table {collection_name} created successfully.")

        # Create index if use_diskann is True
        if self.use_diskann:
            create_index_query = f"""
            CREATE INDEX {collection_name}_vector_idx
            ON {collection_name}
            USING ivfflat (vector vector_cosine_ops)
            WITH (lists = 100);
            """
            self.cur.execute(create_index_query)
            self.conn.commit()
            logger.info(f"Index created on table {collection_name}.")

    def insert_vectors(
        self,
        collection_name: str,
        vectors: List[List[float]],
        payloads: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into a collection.

        Args:
            collection_name (str): The name of the collection.
            vectors (List[List[float]]): A list of vectors to insert.
            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
            ids (Optional[List[str]]): Optional unique identifiers for each vector.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        if ids is None:
            raise ValueError("PGVector requires IDs to be provided for each vector.")
        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

        records = [
            (id_, vector, Json(payload))
            for id_, vector, payload in zip(ids, vectors, payloads)
        ]
        insert_query = f"INSERT INTO {collection_name} (id, vector, payload) VALUES %s;"
        execute_values(self.cur, insert_query, records)
        self.conn.commit()
        logger.info(f"Inserted {len(vectors)} vectors into table {collection_name}.")

    def search_vectors(
        self,
        collection_name: str,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for similar vectors in a collection.

        Args:
            collection_name (str): The name of the collection.
            query_vector (List[float]): The vector to search with.
            top_k (int): The number of top results to return.
            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

        Returns:
            List[Dict[str, Any]]: A list of search results.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        filter_clause = ""
        params = [query_vector]

        if filters:
            filter_conditions = []
            for key, value in filters.items():
                filter_conditions.append(f"payload ->> %s = %s")
                params.extend([key, str(value)])
            filter_clause = "WHERE " + " AND ".join(filter_conditions)

        search_query = f"""
        SELECT id, vector, payload, 1 - (vector <#> %s::vector) AS score
        FROM {collection_name}
        {filter_clause}
        ORDER BY vector <#> %s::vector
        LIMIT %s;
        """
        params.extend([query_vector, top_k])
        self.cur.execute(search_query, params)
        results = self.cur.fetchall()

        output = []
        for row in results:
            result = {
                'id': row[0],
                'score': row[3],
                'payload': row[2]
            }
            output.append(result)
        return output

    def delete_vector(self, collection_name: str, vector_id: str) -> None:
        """
        Delete a vector from a collection by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to delete.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        delete_query = f"DELETE FROM {collection_name} WHERE id = %s;"
        self.cur.execute(delete_query, (vector_id,))
        self.conn.commit()
        logger.info(f"Deleted vector with ID {vector_id} from table {collection_name}.")

    def update_vector(
        self,
        collection_name: str,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Update a vector's data or payload.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to update.
            vector (Optional[List[float]]): The new vector data.
            payload (Optional[Dict[str, Any]]): The new payload data.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        if vector is not None:
            update_query = f"UPDATE {collection_name} SET vector = %s WHERE id = %s;"
            self.cur.execute(update_query, (vector, vector_id))
        if payload is not None:
            update_query = f"UPDATE {collection_name} SET payload = %s WHERE id = %s;"
            self.cur.execute(update_query, (Json(payload), vector_id))
        self.conn.commit()
        logger.info(f"Updated vector with ID {vector_id} in table {collection_name}.")

    def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
        """
        Retrieve a vector by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector.

        Returns:
            Dict[str, Any]: A dictionary containing the vector data and payload.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        select_query = f"SELECT id, vector, payload FROM {collection_name} WHERE id = %s;"
        self.cur.execute(select_query, (vector_id,))
        result = self.cur.fetchone()

        if not result:
            raise KeyError(f"Vector with ID {vector_id} not found in table {collection_name}.")

        vector_data = {
            'id': result[0],
            'vector': result[1],
            'payload': result[2]
        }
        return vector_data

    def list_collections(self) -> List[str]:
        """
        List all available vector collections (tables).

        Returns:
            List[str]: A list of collection names.
        """
        self.cur.execute(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
        )
        tables = self.cur.fetchall()
        return [table[0] for table in tables]

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire vector collection.

        Args:
            collection_name (str): The name of the collection to delete.
        """
        drop_query = f"DROP TABLE IF EXISTS {collection_name};"
        self.cur.execute(drop_query)
        self.conn.commit()
        logger.info(f"Deleted table {collection_name}.")

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """
        Get information about a collection.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            Dict[str, Any]: Information about the collection.
        """
        self.cur.execute(
            "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s;",
            (collection_name,)
        )
        columns = self.cur.fetchall()
        info = {
            'name': collection_name,
            'columns': {column[0]: column[1] for column in columns}
        }
        return info

    def __del__(self):
        """Clean up resources."""
        if hasattr(self, 'cur') and self.cur:
            self.cur.close()
        if hasattr(self, 'conn') and self.conn:
            self.conn.close()
        logger.info("Closed connection to PGVector database.")
__del__()

Clean up resources.

Source code in src/aeiva/storage/pgvector/pgvector_database.py
def __del__(self):
    """Clean up resources."""
    if hasattr(self, 'cur') and self.cur:
        self.cur.close()
    if hasattr(self, 'conn') and self.conn:
        self.conn.close()
    logger.info("Closed connection to PGVector database.")
__init__(config)

Initialize the PGVector vector store.

Parameters:

- config (Dict[str, Any]): Configuration dictionary. Required.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the PGVector vector store.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.collection_name = config.get('collection_name')
    self.dbname = config.get('dbname')
    self.user = config.get('user')
    self.password = config.get('password')
    self.host = config.get('host', 'localhost')
    self.port = config.get('port', 5432)
    self.embedding_model_dims = config.get('embedding_model_dims')
    self.use_diskann = config.get('use_diskann', False)

    if not all([self.collection_name, self.dbname, self.user, self.password, self.embedding_model_dims]):
        raise ValueError("Required configuration parameters are missing.")

    self.create_client()
    self.create_collection(
        collection_name=self.collection_name,
        vector_size=self.embedding_model_dims,
        distance_metric='cosine'  # PGVector uses cosine by default
    )
create_client(**kwargs)

Initializes the client connection to the PGVector database.

Parameters:

- **kwargs: Additional parameters. Default: {}.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def create_client(self, **kwargs) -> None:
    """
    Initializes the client connection to the PGVector database.

    Args:
        **kwargs: Additional parameters.
    """
    try:
        self.conn = psycopg2.connect(
            dbname=self.dbname,
            user=self.user,
            password=self.password,
            host=self.host,
            port=self.port,
            **kwargs
        )
        self.cur = self.conn.cursor()
        logger.info("Connected to PGVector database.")
    except psycopg2.Error as e:
        logger.error(f"Failed to connect to PGVector database: {e}")
        raise ConnectionError(f"Failed to connect to PGVector database: {e}")
create_collection(collection_name, vector_size, distance_metric)

Create a new vector collection (table) in PGVector.

Parameters:

- collection_name (str): The name of the collection. Required.
- vector_size (int): The dimensionality of the vectors. Required.
- distance_metric (str): The distance metric to use (e.g., 'cosine'). Required.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
    """
    Create a new vector collection (table) in PGVector.

    Args:
        collection_name (str): The name of the collection.
        vector_size (int): The dimensionality of the vectors.
        distance_metric (str): The distance metric to use (e.g., 'cosine').
    """
    # Check if table exists
    self.cur.execute(
        "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name=%s);",
        (collection_name,)
    )
    exists = self.cur.fetchone()[0]
    if exists:
        logger.info(f"Table {collection_name} already exists. Skipping creation.")
        return

    # Create table
    create_table_query = f"""
    CREATE TABLE {collection_name} (
        id VARCHAR(64) PRIMARY KEY,
        vector vector({vector_size}),
        payload JSONB
    );
    """
    self.cur.execute(create_table_query)
    self.conn.commit()
    logger.info(f"Table {collection_name} created successfully.")

    # Create index if use_diskann is True
    if self.use_diskann:
        create_index_query = f"""
        CREATE INDEX {collection_name}_vector_idx
        ON {collection_name}
        USING ivfflat (vector vector_cosine_ops)
        WITH (lists = 100);
        """
        self.cur.execute(create_index_query)
        self.conn.commit()
        logger.info(f"Index created on table {collection_name}.")
delete_collection(collection_name)

Delete an entire vector collection.

Parameters:

- collection_name (str): The name of the collection to delete. Required.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def delete_collection(self, collection_name: str) -> None:
    """
    Delete an entire vector collection.

    Args:
        collection_name (str): The name of the collection to delete.
    """
    drop_query = f"DROP TABLE IF EXISTS {collection_name};"
    self.cur.execute(drop_query)
    self.conn.commit()
    logger.info(f"Deleted table {collection_name}.")
delete_vector(collection_name, vector_id)

Delete a vector from a collection by its ID.

Parameters:

- collection_name (str): The name of the collection. Required.
- vector_id (str): The unique identifier of the vector to delete. Required.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:
    """
    Delete a vector from a collection by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to delete.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    delete_query = f"DELETE FROM {collection_name} WHERE id = %s;"
    self.cur.execute(delete_query, (vector_id,))
    self.conn.commit()
    logger.info(f"Deleted vector with ID {vector_id} from table {collection_name}.")
get_collection_info(collection_name)

Get information about a collection.

Parameters:

- collection_name (str): The name of the collection. Required.

Returns:

- Dict[str, Any]: Information about the collection.

Source code in src/aeiva/storage/pgvector/pgvector_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """
    Get information about a collection.

    Args:
        collection_name (str): The name of the collection.

    Returns:
        Dict[str, Any]: Information about the collection.
    """
    self.cur.execute(
        "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s;",
        (collection_name,)
    )
    columns = self.cur.fetchall()
    info = {
        'name': collection_name,
        'columns': {column[0]: column[1] for column in columns}
    }
    return info
get_vector(collection_name, vector_id)

Retrieve a vector by its ID.

Parameters:

- collection_name (str): The name of the collection. Required.
- vector_id (str): The unique identifier of the vector. Required.

Returns:

- Dict[str, Any]: A dictionary containing the vector data and payload.

Source code in src/aeiva/storage/pgvector/pgvector_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
    """
    Retrieve a vector by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector.

    Returns:
        Dict[str, Any]: A dictionary containing the vector data and payload.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    select_query = f"SELECT id, vector, payload FROM {collection_name} WHERE id = %s;"
    self.cur.execute(select_query, (vector_id,))
    result = self.cur.fetchone()

    if not result:
        raise KeyError(f"Vector with ID {vector_id} not found in table {collection_name}.")

    vector_data = {
        'id': result[0],
        'vector': result[1],
        'payload': result[2]
    }
    return vector_data
insert_vectors(collection_name, vectors, payloads=None, ids=None)

Insert vectors into a collection.

Parameters:

- collection_name (str): The name of the collection. Required.
- vectors (List[List[float]]): A list of vectors to insert. Required.
- payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector. Default: None.
- ids (Optional[List[str]]): Optional unique identifiers for each vector. Default: None. Note that PGVector requires ids in practice; the method raises ValueError when they are omitted.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def insert_vectors(
    self,
    collection_name: str,
    vectors: List[List[float]],
    payloads: Optional[List[Dict[str, Any]]] = None,
    ids: Optional[List[str]] = None
) -> None:
    """
    Insert vectors into a collection.

    Args:
        collection_name (str): The name of the collection.
        vectors (List[List[float]]): A list of vectors to insert.
        payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
        ids (Optional[List[str]]): Optional unique identifiers for each vector.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    if ids is None:
        raise ValueError("PGVector requires IDs to be provided for each vector.")
    if payloads is None:
        payloads = [{} for _ in range(len(vectors))]
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

    records = [
        (id_, vector, Json(payload))
        for id_, vector, payload in zip(ids, vectors, payloads)
    ]
    insert_query = f"INSERT INTO {collection_name} (id, vector, payload) VALUES %s;"
    execute_values(self.cur, insert_query, records)
    self.conn.commit()
    logger.info(f"Inserted {len(vectors)} vectors into table {collection_name}.")
list_collections()

List all available vector collections (tables).

Returns:

- List[str]: A list of collection names.

Source code in src/aeiva/storage/pgvector/pgvector_database.py
def list_collections(self) -> List[str]:
    """
    List all available vector collections (tables).

    Returns:
        List[str]: A list of collection names.
    """
    self.cur.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
    )
    tables = self.cur.fetchall()
    return [table[0] for table in tables]
search_vectors(collection_name, query_vector, top_k=5, filters=None)

Search for similar vectors in a collection.

Parameters:

- collection_name (str): The name of the collection. Required.
- query_vector (List[float]): The vector to search with. Required.
- top_k (int): The number of top results to return. Default: 5.
- filters (Optional[Dict[str, Any]]): Optional filters to apply to the search. Default: None.

Returns:

- List[Dict[str, Any]]: A list of search results.

Source code in src/aeiva/storage/pgvector/pgvector_database.py
def search_vectors(
    self,
    collection_name: str,
    query_vector: List[float],
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    Search for similar vectors in a collection.

    Args:
        collection_name (str): The name of the collection.
        query_vector (List[float]): The vector to search with.
        top_k (int): The number of top results to return.
        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

    Returns:
        List[Dict[str, Any]]: A list of search results.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    filter_clause = ""
    params = [query_vector]

    if filters:
        filter_conditions = []
        for key, value in filters.items():
            filter_conditions.append(f"payload ->> %s = %s")
            params.extend([key, str(value)])
        filter_clause = "WHERE " + " AND ".join(filter_conditions)

    search_query = f"""
    SELECT id, vector, payload, 1 - (vector <#> %s::vector) AS score
    FROM {collection_name}
    {filter_clause}
    ORDER BY vector <#> %s::vector
    LIMIT %s;
    """
    params.extend([query_vector, top_k])
    self.cur.execute(search_query, params)
    results = self.cur.fetchall()

    output = []
    for row in results:
        result = {
            'id': row[0],
            'score': row[3],
            'payload': row[2]
        }
        output.append(result)
    return output
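
Continuing the store instance from the class-level sketch, a filtered search; the payload key is hypothetical, and filters match on top-level JSONB keys via payload ->> key = value.

hits = store.search_vectors(
    "documents",
    query_vector=[0.1, 0.2, 0.3],
    top_k=3,
    filters={"source": "demo"},  # only vectors whose payload has source == "demo"
)
for hit in hits:
    print(hit["id"], hit["score"], hit["payload"])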
update_vector(collection_name, vector_id, vector=None, payload=None)

Update a vector's data or payload.

Parameters:

- collection_name (str): The name of the collection. Required.
- vector_id (str): The unique identifier of the vector to update. Required.
- vector (Optional[List[float]]): The new vector data. Default: None.
- payload (Optional[Dict[str, Any]]): The new payload data. Default: None.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def update_vector(
    self,
    collection_name: str,
    vector_id: str,
    vector: Optional[List[float]] = None,
    payload: Optional[Dict[str, Any]] = None
) -> None:
    """
    Update a vector's data or payload.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to update.
        vector (Optional[List[float]]): The new vector data.
        payload (Optional[Dict[str, Any]]): The new payload data.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    if vector is not None:
        update_query = f"UPDATE {collection_name} SET vector = %s WHERE id = %s;"
        self.cur.execute(update_query, (vector, vector_id))
    if payload is not None:
        update_query = f"UPDATE {collection_name} SET payload = %s WHERE id = %s;"
        self.cur.execute(update_query, (Json(payload), vector_id))
    self.conn.commit()
    logger.info(f"Updated vector with ID {vector_id} in table {collection_name}.")

postgresql

postgresql_config

PostgreSQLConfig dataclass

Bases: BaseConfig

Configuration for PostgreSQL database.

Source code in src/aeiva/storage/postgresql/postgresql_config.py
@dataclass
class PostgreSQLConfig(BaseConfig):
    """
    Configuration for PostgreSQL database.
    """
    dbname: str = field(
        default='postgres',
        metadata={"help": "Name of the PostgreSQL database."}
    )
    user: str = field(
        default='postgres',
        metadata={"help": "Username for PostgreSQL authentication."}
    )
    password: str = field(
        default='',
        metadata={"help": "Password for PostgreSQL authentication."}
    )
    host: str = field(
        default='localhost',
        metadata={"help": "Host address for PostgreSQL server."}
    )
    port: int = field(
        default=5432,
        metadata={"help": "Port number for PostgreSQL server."}
    )
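
A minimal configuration sketch; every field falls back to the defaults shown above, and the credentials here are placeholders.

from aeiva.storage.postgresql.postgresql_config import PostgreSQLConfig

config = PostgreSQLConfig(
    dbname="aeiva",
    user="postgres",
    password="secret",
    host="localhost",
    port=5432,
)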

postgresql_database

PostgreSQLDatabase

Bases: RelationalDatabase

Concrete implementation of RelationalStoreBase using PostgreSQL.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
class PostgreSQLDatabase(RelationalDatabase):
    """
    Concrete implementation of RelationalStoreBase using PostgreSQL.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the PostgreSQL database connection.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.connection = None
        self.cursor = None
        self.connect()

    def connect(self) -> None:
        """
        Establishes a connection to the PostgreSQL database.
        """
        try:
            self.connection = psycopg2.connect(
                dbname=self.config.get('dbname'),
                user=self.config.get('user'),
                password=self.config.get('password'),
                host=self.config.get('host'),
                port=self.config.get('port')
            )
            self.connection.autocommit = True  # Enable autocommit for DDL statements
            self.cursor = self.connection.cursor(cursor_factory=RealDictCursor)
        except psycopg2.Error as e:
            raise ConnectionError(f"Failed to connect to PostgreSQL database: {e}")

    def close(self) -> None:
        """
        Closes the database connection and releases resources.
        """
        if self.cursor:
            self.cursor.close()
        if self.connection:
            self.connection.close()

    def insert_record(self, table: str, record: Dict[str, Any]) -> Any:
        """
        Inserts a record into a table.

        Args:
            table (str): The name of the table.
            record (Dict[str, Any]): A dictionary representing the record to insert.

        Returns:
            Any: The primary key of the inserted record.

        Raises:
            StorageError: If there is an issue inserting the record.
        """
        try:
            columns = ', '.join(record.keys())
            placeholders = ', '.join(f"%({key})s" for key in record.keys())
            sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) RETURNING id"
            self.cursor.execute(sql, record)
            result = self.cursor.fetchone()
            return result['id']
        except psycopg2.IntegrityError as e:
            self.connection.rollback()
            raise StorageError(f"Integrity error: {e}")
        except psycopg2.Error as e:
            self.connection.rollback()
            raise StorageError(f"Failed to insert record: {e}")

    def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:
        """
        Retrieves a record by its primary key.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.

        Returns:
            Dict[str, Any]: The retrieved record.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue retrieving the record.
        """
        try:
            sql = f"SELECT * FROM {table} WHERE id = %s"
            self.cursor.execute(sql, (primary_key,))
            row = self.cursor.fetchone()
            if row is None:
                raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
            return dict(row)
        except psycopg2.Error as e:
            raise StorageError(f"Failed to get record: {e}")

    def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:
        """
        Updates a record in a table.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.
            updates (Dict[str, Any]): A dictionary of fields to update.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue updating the record.
        """
        try:
            set_clause = ', '.join(f"{key} = %({key})s" for key in updates.keys())
            sql = f"UPDATE {table} SET {set_clause} WHERE id = %(id)s"
            updates['id'] = primary_key
            self.cursor.execute(sql, updates)
            if self.cursor.rowcount == 0:
                self.connection.rollback()
                raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
        except psycopg2.Error as e:
            self.connection.rollback()
            raise StorageError(f"Failed to update record: {e}")

    def delete_record(self, table: str, primary_key: Any) -> None:
        """
        Deletes a record from a table.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue deleting the record.
        """
        try:
            sql = f"DELETE FROM {table} WHERE id = %s"
            self.cursor.execute(sql, (primary_key,))
            if self.cursor.rowcount == 0:
                self.connection.rollback()
                raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
        except psycopg2.Error as e:
            self.connection.rollback()
            raise StorageError(f"Failed to delete record: {e}")

    def query_records(
        self,
        table: str,
        conditions: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Queries records from a table based on conditions.

        Args:
            table (str): The name of the table.
            conditions (Optional[Dict[str, Any]]): Conditions to filter records.
            limit (Optional[int]): Maximum number of records to return.
            offset (Optional[int]): Number of records to skip.

        Returns:
            List[Dict[str, Any]]: A list of records matching the query.

        Raises:
            StorageError: If there is an issue querying records.
        """
        try:
            sql = f"SELECT * FROM {table}"
            params = {}
            if conditions:
                where_clause = ' AND '.join(f"{key} = %({key})s" for key in conditions.keys())
                sql += f" WHERE {where_clause}"
                params.update(conditions)
            if limit is not None:
                sql += f" LIMIT {limit}"
            if offset is not None:
                sql += f" OFFSET {offset}"
            self.cursor.execute(sql, params)
            rows = self.cursor.fetchall()
            return [dict(row) for row in rows]
        except psycopg2.Error as e:
            raise StorageError(f"Failed to query records: {e}")

    def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:
        """
        Executes a raw SQL query.

        Args:
            query (str): The SQL query string.
            parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.

        Returns:
            Any: The result of the query.

        Raises:
            StorageError: If there is an issue executing the query.
        """
        cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            if parameters:
                cursor.execute(query, parameters)
            else:
                cursor.execute(query)
            if query.strip().upper().startswith("SELECT"):
                return cursor
            else:
                self.connection.commit()
                return cursor
        except psycopg2.Error as e:
            self.connection.rollback()
            raise StorageError(f"Failed to execute SQL query: {e}")

    def begin_transaction(self) -> None:
        """
        Begins a transaction.
        """
        self.connection.autocommit = False

    def commit_transaction(self) -> None:
        """
        Commits the current transaction.
        """
        self.connection.commit()
        self.connection.autocommit = True

    def rollback_transaction(self) -> None:
        """
        Rolls back the current transaction.
        """
        self.connection.rollback()
        self.connection.autocommit = True
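
A brief usage sketch of the class above. It assumes a reachable PostgreSQL server and an existing users table with an id primary key; all values are placeholders.

from aeiva.storage.postgresql.postgresql_database import PostgreSQLDatabase

db = PostgreSQLDatabase({
    "dbname": "aeiva",
    "user": "postgres",
    "password": "secret",
    "host": "localhost",
    "port": 5432,
})

new_id = db.insert_record("users", {"name": "Alice", "email": "alice@example.com"})
record = db.get_record("users", new_id)
db.update_record("users", new_id, {"name": "Alice Smith"})
rows = db.query_records("users", conditions={"name": "Alice Smith"}, limit=10)
db.close()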
__init__(config)

Initialize the PostgreSQL database connection.

Parameters:

- config (Dict[str, Any]): Configuration dictionary. Required.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the PostgreSQL database connection.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.connection = None
    self.cursor = None
    self.connect()
begin_transaction()

Begins a transaction.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def begin_transaction(self) -> None:
    """
    Begins a transaction.
    """
    self.connection.autocommit = False
close()

Closes the database connection and releases resources.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def close(self) -> None:
    """
    Closes the database connection and releases resources.
    """
    if self.cursor:
        self.cursor.close()
    if self.connection:
        self.connection.close()
commit_transaction()

Commits the current transaction.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def commit_transaction(self) -> None:
    """
    Commits the current transaction.
    """
    self.connection.commit()
    self.connection.autocommit = True
connect()

Establishes a connection to the PostgreSQL database.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def connect(self) -> None:
    """
    Establishes a connection to the PostgreSQL database.
    """
    try:
        self.connection = psycopg2.connect(
            dbname=self.config.get('dbname'),
            user=self.config.get('user'),
            password=self.config.get('password'),
            host=self.config.get('host'),
            port=self.config.get('port')
        )
        self.connection.autocommit = True  # Enable autocommit for DDL statements
        self.cursor = self.connection.cursor(cursor_factory=RealDictCursor)
    except psycopg2.Error as e:
        raise ConnectionError(f"Failed to connect to PostgreSQL database: {e}")
delete_record(table, primary_key)

Deletes a record from a table.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue deleting the record.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def delete_record(self, table: str, primary_key: Any) -> None:
    """
    Deletes a record from a table.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue deleting the record.
    """
    try:
        sql = f"DELETE FROM {table} WHERE id = %s"
        self.cursor.execute(sql, (primary_key,))
        if self.cursor.rowcount == 0:
            self.connection.rollback()
            raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
    except psycopg2.Error as e:
        self.connection.rollback()
        raise StorageError(f"Failed to delete record: {e}")
execute_sql(query, parameters=None)

Executes a raw SQL query.

Parameters:

- query (str): The SQL query string. Required.
- parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries. Default: None.

Returns:

- Any: The result of the query.

Raises:

- StorageError: If there is an issue executing the query.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:
    """
    Executes a raw SQL query.

    Args:
        query (str): The SQL query string.
        parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.

    Returns:
        Any: The result of the query.

    Raises:
        StorageError: If there is an issue executing the query.
    """
    cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
    try:
        if parameters:
            cursor.execute(query, parameters)
        else:
            cursor.execute(query)
        if query.strip().upper().startswith("SELECT"):
            return cursor
        else:
            self.connection.commit()
            return cursor
    except psycopg2.Error as e:
        self.connection.rollback()
        raise StorageError(f"Failed to execute SQL query: {e}")
get_record(table, primary_key)

Retrieves a record by its primary key.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.

Returns:

- Dict[str, Any]: The retrieved record.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue retrieving the record.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:
    """
    Retrieves a record by its primary key.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.

    Returns:
        Dict[str, Any]: The retrieved record.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue retrieving the record.
    """
    try:
        sql = f"SELECT * FROM {table} WHERE id = %s"
        self.cursor.execute(sql, (primary_key,))
        row = self.cursor.fetchone()
        if row is None:
            raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
        return dict(row)
    except psycopg2.Error as e:
        raise StorageError(f"Failed to get record: {e}")
insert_record(table, record)

Inserts a record into a table.

Parameters:

- table (str): The name of the table. Required.
- record (Dict[str, Any]): A dictionary representing the record to insert. Required.

Returns:

- Any: The primary key of the inserted record.

Raises:

- StorageError: If there is an issue inserting the record.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def insert_record(self, table: str, record: Dict[str, Any]) -> Any:
    """
    Inserts a record into a table.

    Args:
        table (str): The name of the table.
        record (Dict[str, Any]): A dictionary representing the record to insert.

    Returns:
        Any: The primary key of the inserted record.

    Raises:
        StorageError: If there is an issue inserting the record.
    """
    try:
        columns = ', '.join(record.keys())
        placeholders = ', '.join(f"%({key})s" for key in record.keys())
        sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) RETURNING id"
        self.cursor.execute(sql, record)
        result = self.cursor.fetchone()
        return result['id']
    except psycopg2.IntegrityError as e:
        self.connection.rollback()
        raise StorageError(f"Integrity error: {e}")
    except psycopg2.Error as e:
        self.connection.rollback()
        raise StorageError(f"Failed to insert record: {e}")
query_records(table, conditions=None, limit=None, offset=None)

Queries records from a table based on conditions.

Parameters:

- table (str): The name of the table. Required.
- conditions (Optional[Dict[str, Any]]): Conditions to filter records. Default: None.
- limit (Optional[int]): Maximum number of records to return. Default: None.
- offset (Optional[int]): Number of records to skip. Default: None.

Returns:

- List[Dict[str, Any]]: A list of records matching the query.

Raises:

- StorageError: If there is an issue querying records.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def query_records(
    self,
    table: str,
    conditions: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Queries records from a table based on conditions.

    Args:
        table (str): The name of the table.
        conditions (Optional[Dict[str, Any]]): Conditions to filter records.
        limit (Optional[int]): Maximum number of records to return.
        offset (Optional[int]): Number of records to skip.

    Returns:
        List[Dict[str, Any]]: A list of records matching the query.

    Raises:
        StorageError: If there is an issue querying records.
    """
    try:
        sql = f"SELECT * FROM {table}"
        params = {}
        if conditions:
            where_clause = ' AND '.join(f"{key} = %({key})s" for key in conditions.keys())
            sql += f" WHERE {where_clause}"
            params.update(conditions)
        if limit is not None:
            sql += f" LIMIT {limit}"
        if offset is not None:
            sql += f" OFFSET {offset}"
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchall()
        return [dict(row) for row in rows]
    except psycopg2.Error as e:
        raise StorageError(f"Failed to query records: {e}")
rollback_transaction()

Rolls back the current transaction.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def rollback_transaction(self) -> None:
    """
    Rolls back the current transaction.
    """
    self.connection.rollback()
    self.connection.autocommit = True
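
A transaction sketch combining begin_transaction, commit_transaction, and rollback_transaction; a minimal pattern, not taken from the source, with a hypothetical accounts table.

db.begin_transaction()
try:
    db.insert_record("accounts", {"name": "Alice", "balance": 100})
    db.insert_record("accounts", {"name": "Bob", "balance": 50})
    db.commit_transaction()
except Exception:
    db.rollback_transaction()
    raise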
update_record(table, primary_key, updates)

Updates a record in a table.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.
- updates (Dict[str, Any]): A dictionary of fields to update. Required.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue updating the record.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:
    """
    Updates a record in a table.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.
        updates (Dict[str, Any]): A dictionary of fields to update.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue updating the record.
    """
    try:
        set_clause = ', '.join(f"{key} = %({key})s" for key in updates.keys())
        sql = f"UPDATE {table} SET {set_clause} WHERE id = %(id)s"
        updates['id'] = primary_key
        self.cursor.execute(sql, updates)
        if self.cursor.rowcount == 0:
            self.connection.rollback()
            raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
    except psycopg2.Error as e:
        self.connection.rollback()
        raise StorageError(f"Failed to update record: {e}")
RecordNotFoundError

Bases: Exception

Exception raised when a record is not found in the database.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
class RecordNotFoundError(Exception):
    """Exception raised when a record is not found in the database."""
    pass
StorageError

Bases: Exception

Exception raised when there is a storage-related error in the database.

Source code in src/aeiva/storage/postgresql/postgresql_database.py
class StorageError(Exception):
    """Exception raised when there is a storage-related error in the database."""
    pass
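
A defensive read sketch using the exceptions above, continuing the hypothetical db instance; here a missing row is treated as an expected case.

try:
    record = db.get_record("users", 12345)
except RecordNotFoundError:
    record = None  # absent rows are expected in this flow
except StorageError as e:
    raise RuntimeError(f"Read failed: {e}")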

test

RecordNotFoundError

Bases: Exception

Exception raised when a record is not found in the database.

Source code in src/aeiva/storage/postgresql/test.py
class RecordNotFoundError(Exception):
    """Exception raised when a record is not found in the database."""
    pass

qdrant

qdrant_config

QdrantConfig dataclass

Bases: BaseConfig

Configuration for Qdrant vector database.

Source code in src/aeiva/storage/qdrant/qdrant_config.py
@dataclass
class QdrantConfig(BaseConfig):
    """
    Configuration for Qdrant vector database.
    """

    collection_name: str = field(
        default="mem0",
        metadata={"help": "Name of the collection."}
    )
    embedding_model_dims: int = field(
        default=1536,
        metadata={"help": "Dimensions of the embedding model."}
    )
    client: Optional[Any] = field(
        default=None,
        metadata={"help": "Existing Qdrant client instance (if any)."}
    )
    host: Optional[str] = field(
        default=None,
        metadata={"help": "Host address for Qdrant server."}
    )
    port: Optional[int] = field(
        default=None,
        metadata={"help": "Port for Qdrant server."}
    )
    path: Optional[str] = field(
        default=None,
        metadata={"help": "Path for local Qdrant database storage."}
    )
    url: Optional[str] = field(
        default=None,
        metadata={"help": "Full URL for Qdrant server."}
    )
    api_key: Optional[str] = field(
        default=None,
        metadata={"help": "API key for Qdrant server authentication."}
    )
    on_disk: bool = field(
        default=False,
        metadata={"help": "Whether to enable persistent storage on disk."}
    )

    def __post_init__(self):
        super().__post_init__()
        # Validate that connection parameters are provided
        if not self.path and not ((self.host and self.port) or (self.url and self.api_key)):
            raise ValueError("Provide 'path' for local storage, or 'host' and 'port', or 'url' and 'api_key' for remote connection.")

qdrant_database

QdrantDatabase

Bases: VectorDatabase

Concrete implementation of VectorStoreBase using Qdrant.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
class QdrantDatabase(VectorDatabase):
    """
    Concrete implementation of VectorStoreBase using Qdrant.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the Qdrant vector store.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.collection_name = config.get('collection_name')
        self.embedding_model_dims = config.get('embedding_model_dims')
        self.client = config.get('client')
        self.host = config.get('host')
        self.port = config.get('port')
        self.path = config.get('path')
        self.url = config.get('url')
        self.api_key = config.get('api_key')
        self.on_disk = config.get('on_disk', False)

        if not all([self.collection_name, self.embedding_model_dims]):
            raise ValueError("Required configuration parameters are missing.")

        self.create_client()
        self.create_collection(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            distance_metric='COSINE'
        )

    def create_client(self, **kwargs) -> None:
        """
        Initializes the client connection to the Qdrant vector store.

        Args:
            **kwargs: Additional parameters.
        """
        if self.client:
            return  # Client already provided

        client_params = {}
        if self.api_key:
            client_params['api_key'] = self.api_key
        if self.url:
            client_params['url'] = self.url
        elif self.host and self.port:
            client_params['host'] = self.host
            client_params['port'] = self.port
        else:
            client_params['path'] = self.path

        self.client = QdrantClient(**client_params)
        logger.info("Qdrant client initialized.")

    def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
        """
        Create a new vector collection in Qdrant.

        Args:
            collection_name (str): The name of the collection.
            vector_size (int): The dimensionality of the vectors.
            distance_metric (str): The distance metric to use (e.g., 'COSINE').
        """
        # Check if collection exists
        collections = self.list_collections()
        if collection_name in collections:
            logger.info(f"Collection {collection_name} already exists. Skipping creation.")
            return

        vector_params = VectorParams(
            size=vector_size,
            distance=getattr(Distance, distance_metric.upper()),
            on_disk=self.on_disk
        )
        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=vector_params
        )
        logger.info(f"Collection {collection_name} created successfully.")

    def insert_vectors(
        self,
        collection_name: str,
        vectors: List[List[float]],
        payloads: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into a collection.

        Args:
            collection_name (str): The name of the collection.
            vectors (List[List[float]]): A list of vectors to insert.
            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
            ids (Optional[List[str]]): Optional unique identifiers for each vector.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        if ids is None:
            ids = list(range(len(vectors)))  # sequential integer IDs; Qdrant accepts ints as well as UUID strings
        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

        points = [
            PointStruct(
                id=id_,
                vector=vector,
                payload=payload
            )
            for id_, vector, payload in zip(ids, vectors, payloads)
        ]
        self.client.upsert(
            collection_name=collection_name,
            points=points
        )
        logger.info(f"Inserted {len(vectors)} vectors into collection {collection_name}.")

    def search_vectors(
        self,
        collection_name: str,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for similar vectors in a collection.

        Args:
            collection_name (str): The name of the collection.
            query_vector (List[float]): The vector to search with.
            top_k (int): The number of top results to return.
            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

        Returns:
            List[Dict[str, Any]]: A list of search results.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        query_filter = self._build_filter(filters)
        results = self.client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=top_k,
            query_filter=query_filter
        )

        output = []
        for hit in results:
            result = {
                'id': hit.id,
                'score': hit.score,
                'payload': hit.payload
            }
            output.append(result)
        return output

    def _build_filter(self, filters: Optional[Dict[str, Any]]) -> Optional[Filter]:
        """
        Build a Qdrant filter object from a dictionary.

        Args:
            filters (Optional[Dict[str, Any]]): Filters to apply.

        Returns:
            Optional[Filter]: A Qdrant Filter object.
        """
        if not filters:
            return None

        conditions = []
        for key, value in filters.items():
            conditions.append(
                FieldCondition(
                    key=key,
                    match=MatchValue(value=value)
                )
            )
        return Filter(must=conditions)

    def delete_vector(self, collection_name: str, vector_id: str) -> None:
        """
        Delete a vector from a collection by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to delete.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        self.client.delete(
            collection_name=collection_name,
            points_selector=[vector_id]
        )
        logger.info(f"Deleted vector with ID {vector_id} from collection {collection_name}.")

    def update_vector(
        self,
        collection_name: str,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Update a vector's data or payload.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to update.
            vector (Optional[List[float]]): The new vector data.
            payload (Optional[Dict[str, Any]]): The new payload data.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        point = PointStruct(
            id=vector_id,
            vector=vector,
            payload=payload
        )
        self.client.upsert(
            collection_name=collection_name,
            points=[point]
        )
        logger.info(f"Updated vector with ID {vector_id} in collection {collection_name}.")

    def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
        """
        Retrieve a vector by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector.

        Returns:
            Dict[str, Any]: A dictionary containing the vector data and payload.
        """
        if collection_name != self.collection_name:
            raise ValueError("Collection name does not match initialized collection name.")

        result = self.client.retrieve(
            collection_name=collection_name,
            ids=[vector_id]
        )
        if not result:
            raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

        point = result[0]
        vector_data = {
            'id': point.id,
            'vector': point.vector,
            'payload': point.payload
        }
        return vector_data

    def list_collections(self) -> List[str]:
        """
        List all available vector collections.

        Returns:
            List[str]: A list of collection names.
        """
        collections = self.client.get_collections().collections
        return [collection.name for collection in collections]

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire vector collection.

        Args:
            collection_name (str): The name of the collection to delete.
        """
        self.client.delete_collection(collection_name=collection_name)
        logger.info(f"Deleted collection {collection_name}.")

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """
        Get information about a collection.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            Dict[str, Any]: Information about the collection.
        """
        info = self.client.get_collection(collection_name=collection_name)
        return info.dict()
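
A quick end-to-end sketch of this class (assuming the qdrant-client package is installed; the local path, collection name, and payload below are illustrative):

import uuid

db = QdrantDatabase(config={
    'collection_name': 'demo',
    'embedding_model_dims': 4,
    'path': './qdrant_demo',   # local embedded mode
})
vid = str(uuid.uuid4())
db.insert_vectors('demo', vectors=[[0.1, 0.2, 0.3, 0.4]],
                  payloads=[{'label': 'example'}], ids=[vid])
print(db.search_vectors('demo', query_vector=[0.1, 0.2, 0.3, 0.4], top_k=1))
print(db.get_vector('demo', vid))
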
__init__(config)

Initialize the Qdrant vector store.

Parameters:

- config (Dict[str, Any]): Configuration dictionary. Required.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the Qdrant vector store.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.collection_name = config.get('collection_name')
    self.embedding_model_dims = config.get('embedding_model_dims')
    self.client = config.get('client')
    self.host = config.get('host')
    self.port = config.get('port')
    self.path = config.get('path')
    self.url = config.get('url')
    self.api_key = config.get('api_key')
    self.on_disk = config.get('on_disk', False)

    if not all([self.collection_name, self.embedding_model_dims]):
        raise ValueError("Required configuration parameters are missing.")

    self.create_client()
    self.create_collection(
        collection_name=self.collection_name,
        vector_size=self.embedding_model_dims,
        distance_metric='COSINE'
    )
create_client(**kwargs)

Initializes the client connection to the Qdrant vector store.

Parameters:

- **kwargs: Additional parameters. Default: {}.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def create_client(self, **kwargs) -> None:
    """
    Initializes the client connection to the Qdrant vector store.

    Args:
        **kwargs: Additional parameters.
    """
    if self.client:
        return  # Client already provided

    client_params = {}
    if self.api_key:
        client_params['api_key'] = self.api_key
    if self.url:
        client_params['url'] = self.url
    elif self.host and self.port:
        client_params['host'] = self.host
        client_params['port'] = self.port
    else:
        client_params['path'] = self.path

    self.client = QdrantClient(**client_params)
    logger.info("Qdrant client initialized.")
create_collection(collection_name, vector_size, distance_metric)

Create a new vector collection in Qdrant.

Parameters:

- collection_name (str): The name of the collection. Required.
- vector_size (int): The dimensionality of the vectors. Required.
- distance_metric (str): The distance metric to use (e.g., 'COSINE'). Required.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
    """
    Create a new vector collection in Qdrant.

    Args:
        collection_name (str): The name of the collection.
        vector_size (int): The dimensionality of the vectors.
        distance_metric (str): The distance metric to use (e.g., 'COSINE').
    """
    # Check if collection exists
    collections = self.list_collections()
    if collection_name in collections:
        logger.info(f"Collection {collection_name} already exists. Skipping creation.")
        return

    vector_params = VectorParams(
        size=vector_size,
        distance=getattr(Distance, distance_metric.upper()),
        on_disk=self.on_disk
    )
    self.client.create_collection(
        collection_name=collection_name,
        vectors_config=vector_params
    )
    logger.info(f"Collection {collection_name} created successfully.")
delete_collection(collection_name)

Delete an entire vector collection.

Parameters:

- collection_name (str): The name of the collection to delete. Required.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def delete_collection(self, collection_name: str) -> None:
    """
    Delete an entire vector collection.

    Args:
        collection_name (str): The name of the collection to delete.
    """
    self.client.delete_collection(collection_name=collection_name)
    logger.info(f"Deleted collection {collection_name}.")
delete_vector(collection_name, vector_id)

Delete a vector from a collection by its ID.

Parameters:

- collection_name (str): The name of the collection. Required.
- vector_id (str): The unique identifier of the vector to delete. Required.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:
    """
    Delete a vector from a collection by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to delete.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    self.client.delete(
        collection_name=collection_name,
        points_selector=[vector_id]
    )
    logger.info(f"Deleted vector with ID {vector_id} from collection {collection_name}.")
get_collection_info(collection_name)

Get information about a collection.

Parameters:

- collection_name (str): The name of the collection. Required.

Returns:

- Dict[str, Any]: Information about the collection.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """
    Get information about a collection.

    Args:
        collection_name (str): The name of the collection.

    Returns:
        Dict[str, Any]: Information about the collection.
    """
    info = self.client.get_collection(collection_name=collection_name)
    return info.dict()
get_vector(collection_name, vector_id)

Retrieve a vector by its ID.

Parameters:

- collection_name (str): The name of the collection. Required.
- vector_id (str): The unique identifier of the vector. Required.

Returns:

- Dict[str, Any]: A dictionary containing the vector data and payload.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
    """
    Retrieve a vector by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector.

    Returns:
        Dict[str, Any]: A dictionary containing the vector data and payload.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    result = self.client.retrieve(
        collection_name=collection_name,
        ids=[vector_id]
    )
    if not result:
        raise KeyError(f"Vector with ID {vector_id} not found in collection {collection_name}.")

    point = result[0]
    vector_data = {
        'id': point.id,
        'vector': point.vector,
        'payload': point.payload
    }
    return vector_data
insert_vectors(collection_name, vectors, payloads=None, ids=None)

Insert vectors into a collection.

Parameters:

- collection_name (str): The name of the collection. Required.
- vectors (List[List[float]]): A list of vectors to insert. Required.
- payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector. Default: None.
- ids (Optional[List[str]]): Optional unique identifiers for each vector. Default: None.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def insert_vectors(
    self,
    collection_name: str,
    vectors: List[List[float]],
    payloads: Optional[List[Dict[str, Any]]] = None,
    ids: Optional[List[str]] = None
) -> None:
    """
    Insert vectors into a collection.

    Args:
        collection_name (str): The name of the collection.
        vectors (List[List[float]]): A list of vectors to insert.
        payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
        ids (Optional[List[str]]): Optional unique identifiers for each vector.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    if ids is None:
        ids = list(range(len(vectors)))  # sequential integer IDs; Qdrant accepts ints as well as UUID strings
    if payloads is None:
        payloads = [{} for _ in range(len(vectors))]
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

    points = [
        PointStruct(
            id=id_,
            vector=vector,
            payload=payload
        )
        for id_, vector, payload in zip(ids, vectors, payloads)
    ]
    self.client.upsert(
        collection_name=collection_name,
        points=points
    )
    logger.info(f"Inserted {len(vectors)} vectors into collection {collection_name}.")
list_collections()

List all available vector collections.

Returns:

- List[str]: A list of collection names.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def list_collections(self) -> List[str]:
    """
    List all available vector collections.

    Returns:
        List[str]: A list of collection names.
    """
    collections = self.client.get_collections().collections
    return [collection.name for collection in collections]
search_vectors(collection_name, query_vector, top_k=5, filters=None)

Search for similar vectors in a collection.

Parameters:

- collection_name (str): The name of the collection. Required.
- query_vector (List[float]): The vector to search with. Required.
- top_k (int): The number of top results to return. Default: 5.
- filters (Optional[Dict[str, Any]]): Optional filters to apply to the search. Default: None.

Returns:

- List[Dict[str, Any]]: A list of search results.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def search_vectors(
    self,
    collection_name: str,
    query_vector: List[float],
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    Search for similar vectors in a collection.

    Args:
        collection_name (str): The name of the collection.
        query_vector (List[float]): The vector to search with.
        top_k (int): The number of top results to return.
        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

    Returns:
        List[Dict[str, Any]]: A list of search results.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    query_filter = self._build_filter(filters)
    results = self.client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=top_k,
        query_filter=query_filter
    )

    output = []
    for hit in results:
        result = {
            'id': hit.id,
            'score': hit.score,
            'payload': hit.payload
        }
        output.append(result)
    return output
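
A filtered-search sketch (the 'demo' collection and 'category' payload key are illustrative); each key/value pair in filters becomes an exact-match FieldCondition via _build_filter:

hits = db.search_vectors(
    collection_name='demo',
    query_vector=[0.1, 0.2, 0.3, 0.4],
    top_k=5,
    filters={'category': 'news'},  # exact payload match
)
for hit in hits:
    print(hit['id'], hit['score'], hit['payload'])
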
update_vector(collection_name, vector_id, vector=None, payload=None)

Update a vector's data or payload.

Parameters:

- collection_name (str): The name of the collection. Required.
- vector_id (str): The unique identifier of the vector to update. Required.
- vector (Optional[List[float]]): The new vector data. Default: None.
- payload (Optional[Dict[str, Any]]): The new payload data. Default: None.

Source code in src/aeiva/storage/qdrant/qdrant_database.py
def update_vector(
    self,
    collection_name: str,
    vector_id: str,
    vector: Optional[List[float]] = None,
    payload: Optional[Dict[str, Any]] = None
) -> None:
    """
    Update a vector's data or payload.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to update.
        vector (Optional[List[float]]): The new vector data.
        payload (Optional[Dict[str, Any]]): The new payload data.
    """
    if collection_name != self.collection_name:
        raise ValueError("Collection name does not match initialized collection name.")

    point = PointStruct(
        id=vector_id,
        vector=vector,
        payload=payload
    )
    self.client.upsert(
        collection_name=collection_name,
        points=[point]
    )
    logger.info(f"Updated vector with ID {vector_id} in collection {collection_name}.")

relational_database

RelationalDatabase

Bases: ABC

Abstract base class for relational database operations.

Source code in src/aeiva/storage/relational_database.py
class RelationalDatabase(ABC):
    """
    Abstract base class for relational database operations.
    """

    @abstractmethod
    def insert_record(self, table: str, record: Dict[str, Any]) -> Any:
        """
        Inserts a record into a table.

        Args:
            table (str): The name of the table.
            record (Dict[str, Any]): A dictionary representing the record to insert.

        Returns:
            Any: The primary key of the inserted record.

        Raises:
            StorageError: If there is an issue inserting the record.
        """
        pass

    @abstractmethod
    def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:
        """
        Retrieves a record by its primary key.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.

        Returns:
            Dict[str, Any]: The retrieved record.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue retrieving the record.
        """
        pass

    @abstractmethod
    def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:
        """
        Updates a record in a table.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.
            updates (Dict[str, Any]): A dictionary of fields to update.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue updating the record.
        """
        pass

    @abstractmethod
    def delete_record(self, table: str, primary_key: Any) -> None:
        """
        Deletes a record from a table.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue deleting the record.
        """
        pass

    @abstractmethod
    def query_records(self, table: str, conditions: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Queries records from a table based on conditions.

        Args:
            table (str): The name of the table.
            conditions (Optional[Dict[str, Any]]): Conditions to filter records.
            limit (Optional[int]): Maximum number of records to return.
            offset (Optional[int]): Number of records to skip.

        Returns:
            List[Dict[str, Any]]: A list of records matching the query.

        Raises:
            StorageError: If there is an issue querying records.
        """
        pass

    @abstractmethod
    def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:
        """
        Executes a raw SQL query.

        Args:
            query (str): The SQL query string.
            parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.

        Returns:
            Any: The result of the query.

        Raises:
            StorageError: If there is an issue executing the query.
        """
        pass

    @abstractmethod
    def begin_transaction(self) -> None:
        """
        Begins a transaction.
        """
        pass

    @abstractmethod
    def commit_transaction(self) -> None:
        """
        Commits the current transaction.
        """
        pass

    @abstractmethod
    def rollback_transaction(self) -> None:
        """
        Rolls back the current transaction.
        """
        pass

    @abstractmethod
    def close(self) -> None:
        """
        Closes the database connection and releases resources.
        """
        pass
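
The call pattern this interface is designed for, sketched with a hypothetical concrete subclass (SomeRelationalDatabase, the 'users' table, and its columns are illustrative):

db = SomeRelationalDatabase(config)  # any concrete implementation
db.begin_transaction()
try:
    pk = db.insert_record('users', {'name': 'Ada'})
    db.update_record('users', pk, {'name': 'Ada Lovelace'})
    db.commit_transaction()
except Exception:
    db.rollback_transaction()
    raise
finally:
    db.close()
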
begin_transaction() abstractmethod

Begins a transaction.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def begin_transaction(self) -> None:
    """
    Begins a transaction.
    """
    pass
close() abstractmethod

Closes the database connection and releases resources.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def close(self) -> None:
    """
    Closes the database connection and releases resources.
    """
    pass
commit_transaction() abstractmethod

Commits the current transaction.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def commit_transaction(self) -> None:
    """
    Commits the current transaction.
    """
    pass
delete_record(table, primary_key) abstractmethod

Deletes a record from a table.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue deleting the record.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def delete_record(self, table: str, primary_key: Any) -> None:
    """
    Deletes a record from a table.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue deleting the record.
    """
    pass
execute_sql(query, parameters=None) abstractmethod

Executes a raw SQL query.

Parameters:

- query (str): The SQL query string. Required.
- parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries. Default: None.

Returns:

- Any: The result of the query.

Raises:

- StorageError: If there is an issue executing the query.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:
    """
    Executes a raw SQL query.

    Args:
        query (str): The SQL query string.
        parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.

    Returns:
        Any: The result of the query.

    Raises:
        StorageError: If there is an issue executing the query.
    """
    pass
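
A parameterized-query sketch against this signature (the 'users' table and the named-parameter style are illustrative; the concrete return value depends on the implementation):

result = db.execute_sql(
    "SELECT * FROM users WHERE name = :name",
    parameters={'name': 'Ada'},
)
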
get_record(table, primary_key) abstractmethod

Retrieves a record by its primary key.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.

Returns:

- Dict[str, Any]: The retrieved record.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue retrieving the record.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:
    """
    Retrieves a record by its primary key.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.

    Returns:
        Dict[str, Any]: The retrieved record.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue retrieving the record.
    """
    pass
insert_record(table, record) abstractmethod

Inserts a record into a table.

Parameters:

- table (str): The name of the table. Required.
- record (Dict[str, Any]): A dictionary representing the record to insert. Required.

Returns:

- Any: The primary key of the inserted record.

Raises:

- StorageError: If there is an issue inserting the record.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def insert_record(self, table: str, record: Dict[str, Any]) -> Any:
    """
    Inserts a record into a table.

    Args:
        table (str): The name of the table.
        record (Dict[str, Any]): A dictionary representing the record to insert.

    Returns:
        Any: The primary key of the inserted record.

    Raises:
        StorageError: If there is an issue inserting the record.
    """
    pass
query_records(table, conditions=None, limit=None, offset=None) abstractmethod

Queries records from a table based on conditions.

Parameters:

- table (str): The name of the table. Required.
- conditions (Optional[Dict[str, Any]]): Conditions to filter records. Default: None.
- limit (Optional[int]): Maximum number of records to return. Default: None.
- offset (Optional[int]): Number of records to skip. Default: None.

Returns:

- List[Dict[str, Any]]: A list of records matching the query.

Raises:

- StorageError: If there is an issue querying records.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def query_records(self, table: str, conditions: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[Dict[str, Any]]:
    """
    Queries records from a table based on conditions.

    Args:
        table (str): The name of the table.
        conditions (Optional[Dict[str, Any]]): Conditions to filter records.
        limit (Optional[int]): Maximum number of records to return.
        offset (Optional[int]): Number of records to skip.

    Returns:
        List[Dict[str, Any]]: A list of records matching the query.

    Raises:
        StorageError: If there is an issue querying records.
    """
    pass
rollback_transaction() abstractmethod

Rolls back the current transaction.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def rollback_transaction(self) -> None:
    """
    Rolls back the current transaction.
    """
    pass
update_record(table, primary_key, updates) abstractmethod

Updates a record in a table.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.
- updates (Dict[str, Any]): A dictionary of fields to update. Required.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue updating the record.

Source code in src/aeiva/storage/relational_database.py
@abstractmethod
def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:
    """
    Updates a record in a table.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.
        updates (Dict[str, Any]): A dictionary of fields to update.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue updating the record.
    """
    pass

sqlite

sqlite_config

SQLiteConfig dataclass

Bases: BaseConfig

Configuration for SQLite database.

Source code in src/aeiva/storage/sqlite/sqlite_config.py
@dataclass
class SQLiteConfig(BaseConfig):
    """
    Configuration for SQLite database.
    """
    database: str = field(
        default=':memory:',
        metadata={"help": "Path to the SQLite database file. Use ':memory:' for an in-memory database."}
    )
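
For example, a minimal sketch of the two usual setups (assuming SQLiteConfig is importable from aeiva.storage.sqlite.sqlite_config):

from aeiva.storage.sqlite.sqlite_config import SQLiteConfig  # assumed import path

mem_cfg = SQLiteConfig()                    # in-memory database (':memory:')
file_cfg = SQLiteConfig(database='app.db')  # on-disk database file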

sqlite_database

RecordNotFoundError

Bases: Exception

Exception raised when a record is not found in the database.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
class RecordNotFoundError(Exception):
    """Exception raised when a record is not found in the database."""
    pass
SQLiteDatabase

Bases: RelationalDatabase

Concrete implementation of RelationalStoreBase using SQLite.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
class SQLiteDatabase(RelationalDatabase):
    """
    Concrete implementation of RelationalStoreBase using SQLite.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the SQLite database connection.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.database = config.get('database', ':memory:')
        self.connection = None
        self.cursor = None
        self.connect()

    def connect(self) -> None:
        """
        Establishes a connection to the SQLite database.
        """
        try:
            self.connection = sqlite3.connect(self.database)
            self.connection.row_factory = sqlite3.Row  # To get dict-like rows
            self.cursor = self.connection.cursor()
            # self.connection.execute('PRAGMA foreign_keys = ON')  # Enable foreign key support
        except sqlite3.Error as e:
            raise ConnectionError(f"Failed to connect to SQLite database: {e}")

    def close(self) -> None:
        """
        Closes the database connection and releases resources.
        """
        if self.cursor:
            self.cursor.close()
        if self.connection:
            self.connection.close()

    def insert_record(self, table: str, record: Dict[str, Any]) -> Any:
        """
        Inserts a record into a table.

        Args:
            table (str): The name of the table.
            record (Dict[str, Any]): A dictionary representing the record to insert.

        Returns:
            Any: The primary key of the inserted record.

        Raises:
            StorageError: If there is an issue inserting the record.
        """
        try:
            columns = ', '.join(record.keys())
            placeholders = ', '.join('?' for _ in record)
            sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
            values = list(record.values())
            self.cursor.execute(sql, values)
            self.connection.commit()
            return self.cursor.lastrowid
        except sqlite3.IntegrityError as e:
            self.connection.rollback()
            raise StorageError(f"Integrity error: {e}")
        except sqlite3.Error as e:
            self.connection.rollback()
            raise StorageError(f"Failed to insert record: {e}")

    def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:
        """
        Retrieves a record by its primary key.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.

        Returns:
            Dict[str, Any]: The retrieved record.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue retrieving the record.
        """
        try:
            sql = f"SELECT * FROM {table} WHERE id = ?"
            self.cursor.execute(sql, (primary_key,))
            row = self.cursor.fetchone()
            if row is None:
                raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
            return dict(row)
        except sqlite3.Error as e:
            raise StorageError(f"Failed to get record: {e}")

    def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:
        """
        Updates a record in a table.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.
            updates (Dict[str, Any]): A dictionary of fields to update.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue updating the record.
        """
        try:
            set_clause = ', '.join(f"{key} = ?" for key in updates.keys())
            sql = f"UPDATE {table} SET {set_clause} WHERE id = ?"
            values = list(updates.values()) + [primary_key]
            self.cursor.execute(sql, values)
            if self.cursor.rowcount == 0:
                self.connection.rollback()
                raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
            self.connection.commit()
        except sqlite3.Error as e:
            self.connection.rollback()
            raise StorageError(f"Failed to update record: {e}")

    def delete_record(self, table: str, primary_key: Any) -> None:
        """
        Deletes a record from a table.

        Args:
            table (str): The name of the table.
            primary_key (Any): The primary key of the record.

        Raises:
            RecordNotFoundError: If the record does not exist.
            StorageError: If there is an issue deleting the record.
        """
        try:
            sql = f"DELETE FROM {table} WHERE id = ?"
            self.cursor.execute(sql, (primary_key,))
            if self.cursor.rowcount == 0:
                self.connection.rollback()
                raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
            self.connection.commit()
        except sqlite3.Error as e:
            self.connection.rollback()
            raise StorageError(f"Failed to delete record: {e}")

    def query_records(
        self,
        table: str,
        conditions: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Queries records from a table based on conditions.

        Args:
            table (str): The name of the table.
            conditions (Optional[Dict[str, Any]]): Conditions to filter records.
            limit (Optional[int]): Maximum number of records to return.
            offset (Optional[int]): Number of records to skip.

        Returns:
            List[Dict[str, Any]]: A list of records matching the query.

        Raises:
            StorageError: If there is an issue querying records.
        """
        try:
            sql = f"SELECT * FROM {table}"
            params = []
            if conditions:
                where_clause = ' AND '.join(f"{key} = ?" for key in conditions.keys())
                sql += f" WHERE {where_clause}"
                params.extend(conditions.values())
            if limit is not None:
                sql += f" LIMIT {limit}"
            if offset is not None:
                sql += f" OFFSET {offset}"
            self.cursor.execute(sql, params)
            rows = self.cursor.fetchall()
            return [dict(row) for row in rows]
        except sqlite3.Error as e:
            raise StorageError(f"Failed to query records: {e}")

    def execute_sql(self, query: str, params: Optional[Tuple] = None):
        """
        Executes a SQL query and returns the cursor.

        Args:
            query (str): The SQL query to execute.
            params (Optional[Tuple]): Parameters to substitute into the query.

        Returns:
            sqlite3.Cursor: The cursor after executing the query.
        """
        cursor = self.connection.cursor()
        try:
            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)
            # For SELECT queries, do not commit. For INSERT/UPDATE/DELETE, you may need to commit.
            if query.strip().upper().startswith("SELECT"):
                return cursor
            else:
                self.connection.commit()
                return cursor
        except sqlite3.Error as e:
            print(f"SQLite query failed: {e}")
            raise e

    def begin_transaction(self) -> None:
        """
        Begins a transaction.
        """
        self.connection.isolation_level = None
        self.cursor.execute('BEGIN')

    def commit_transaction(self) -> None:
        """
        Commits the current transaction.
        """
        self.connection.commit()
        self.connection.isolation_level = None

    def rollback_transaction(self) -> None:
        """
        Rolls back the current transaction.
        """
        self.connection.rollback()
        self.connection.isolation_level = None
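
A short usage sketch of this class (the 'notes' table and its columns are illustrative; note that get_record, update_record, and delete_record assume the primary-key column is named 'id'):

db = SQLiteDatabase(config={'database': ':memory:'})
db.execute_sql("CREATE TABLE notes (id INTEGER PRIMARY KEY, body TEXT)")
pk = db.insert_record('notes', {'body': 'hello'})
print(db.get_record('notes', pk))   # {'id': 1, 'body': 'hello'}
rows = db.query_records('notes', conditions={'body': 'hello'}, limit=10)
db.close()
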
__init__(config)

Initialize the SQLite database connection.

Parameters:

- config (Dict[str, Any]): Configuration dictionary. Required.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the SQLite database connection.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.database = config.get('database', ':memory:')
    self.connection = None
    self.cursor = None
    self.connect()
begin_transaction()

Begins a transaction.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def begin_transaction(self) -> None:
    """
    Begins a transaction.
    """
    self.connection.isolation_level = None
    self.cursor.execute('BEGIN')
close()

Closes the database connection and releases resources.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def close(self) -> None:
    """
    Closes the database connection and releases resources.
    """
    if self.cursor:
        self.cursor.close()
    if self.connection:
        self.connection.close()
commit_transaction()

Commits the current transaction.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def commit_transaction(self) -> None:
    """
    Commits the current transaction.
    """
    self.connection.commit()
    self.connection.isolation_level = None
connect()

Establishes a connection to the SQLite database.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def connect(self) -> None:
    """
    Establishes a connection to the SQLite database.
    """
    try:
        self.connection = sqlite3.connect(self.database)
        self.connection.row_factory = sqlite3.Row  # To get dict-like rows
        self.cursor = self.connection.cursor()
        # self.connection.execute('PRAGMA foreign_keys = ON')  # Enable foreign key support
    except sqlite3.Error as e:
        raise ConnectionError(f"Failed to connect to SQLite database: {e}")
delete_record(table, primary_key)

Deletes a record from a table.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue deleting the record.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def delete_record(self, table: str, primary_key: Any) -> None:
    """
    Deletes a record from a table.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue deleting the record.
    """
    try:
        sql = f"DELETE FROM {table} WHERE id = ?"
        self.cursor.execute(sql, (primary_key,))
        if self.cursor.rowcount == 0:
            self.connection.rollback()
            raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
        self.connection.commit()
    except sqlite3.Error as e:
        self.connection.rollback()
        raise StorageError(f"Failed to delete record: {e}")
execute_sql(query, params=None)

Executes a SQL query and returns the cursor.

Parameters:

- query (str): The SQL query to execute. Required.
- params (Optional[Tuple]): Parameters to substitute into the query. Default: None.

Returns:

- sqlite3.Cursor: The cursor after executing the query.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def execute_sql(self, query: str, params: Optional[Tuple] = None):
    """
    Executes a SQL query and returns the cursor.

    Args:
        query (str): The SQL query to execute.
        params (Optional[Tuple]): Parameters to substitute into the query.

    Returns:
        sqlite3.Cursor: The cursor after executing the query.
    """
    cursor = self.connection.cursor()
    try:
        if params:
            cursor.execute(query, params)
        else:
            cursor.execute(query)
        # For SELECT queries, do not commit. For INSERT/UPDATE/DELETE, you may need to commit.
        if query.strip().upper().startswith("SELECT"):
            return cursor
        else:
            self.connection.commit()
            return cursor
    except sqlite3.Error as e:
        print(f"SQLite query failed: {e}")
        raise e
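
For instance, a parameterized SELECT through this method (the 'notes' table is illustrative; the returned cursor yields dict-like sqlite3.Row objects because connect() sets the connection's row_factory):

cur = db.execute_sql("SELECT * FROM notes WHERE id = ?", (1,))
row = cur.fetchone()
print(dict(row) if row else None)
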
get_record(table, primary_key)

Retrieves a record by its primary key.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.

Returns:

- Dict[str, Any]: The retrieved record.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue retrieving the record.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:
    """
    Retrieves a record by its primary key.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.

    Returns:
        Dict[str, Any]: The retrieved record.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue retrieving the record.
    """
    try:
        sql = f"SELECT * FROM {table} WHERE id = ?"
        self.cursor.execute(sql, (primary_key,))
        row = self.cursor.fetchone()
        if row is None:
            raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
        return dict(row)
    except sqlite3.Error as e:
        raise StorageError(f"Failed to get record: {e}")
insert_record(table, record)

Inserts a record into a table.

Parameters:

- table (str): The name of the table. Required.
- record (Dict[str, Any]): A dictionary representing the record to insert. Required.

Returns:

- Any: The primary key of the inserted record.

Raises:

- StorageError: If there is an issue inserting the record.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def insert_record(self, table: str, record: Dict[str, Any]) -> Any:
    """
    Inserts a record into a table.

    Args:
        table (str): The name of the table.
        record (Dict[str, Any]): A dictionary representing the record to insert.

    Returns:
        Any: The primary key of the inserted record.

    Raises:
        StorageError: If there is an issue inserting the record.
    """
    try:
        columns = ', '.join(record.keys())
        placeholders = ', '.join('?' for _ in record)
        sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
        values = list(record.values())
        self.cursor.execute(sql, values)
        self.connection.commit()
        return self.cursor.lastrowid
    except sqlite3.IntegrityError as e:
        self.connection.rollback()
        raise StorageError(f"Integrity error: {e}")
    except sqlite3.Error as e:
        self.connection.rollback()
        raise StorageError(f"Failed to insert record: {e}")
query_records(table, conditions=None, limit=None, offset=None)

Queries records from a table based on conditions.

Parameters:

- table (str): The name of the table. Required.
- conditions (Optional[Dict[str, Any]]): Conditions to filter records. Default: None.
- limit (Optional[int]): Maximum number of records to return. Default: None.
- offset (Optional[int]): Number of records to skip. Default: None.

Returns:

- List[Dict[str, Any]]: A list of records matching the query.

Raises:

- StorageError: If there is an issue querying records.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def query_records(
    self,
    table: str,
    conditions: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Queries records from a table based on conditions.

    Args:
        table (str): The name of the table.
        conditions (Optional[Dict[str, Any]]): Conditions to filter records.
        limit (Optional[int]): Maximum number of records to return.
        offset (Optional[int]): Number of records to skip.

    Returns:
        List[Dict[str, Any]]: A list of records matching the query.

    Raises:
        StorageError: If there is an issue querying records.
    """
    try:
        sql = f"SELECT * FROM {table}"
        params = []
        if conditions:
            where_clause = ' AND '.join(f"{key} = ?" for key in conditions.keys())
            sql += f" WHERE {where_clause}"
            params.extend(conditions.values())
        if limit is not None:
            sql += f" LIMIT {limit}"
        if offset is not None:
            sql += f" OFFSET {offset}"
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchall()
        return [dict(row) for row in rows]
    except sqlite3.Error as e:
        raise StorageError(f"Failed to query records: {e}")
rollback_transaction()

Rolls back the current transaction.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def rollback_transaction(self) -> None:
    """
    Rolls back the current transaction.
    """
    self.connection.rollback()
    self.connection.isolation_level = None
update_record(table, primary_key, updates)

Updates a record in a table.

Parameters:

- table (str): The name of the table. Required.
- primary_key (Any): The primary key of the record. Required.
- updates (Dict[str, Any]): A dictionary of fields to update. Required.

Raises:

- RecordNotFoundError: If the record does not exist.
- StorageError: If there is an issue updating the record.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:
    """
    Updates a record in a table.

    Args:
        table (str): The name of the table.
        primary_key (Any): The primary key of the record.
        updates (Dict[str, Any]): A dictionary of fields to update.

    Raises:
        RecordNotFoundError: If the record does not exist.
        StorageError: If there is an issue updating the record.
    """
    try:
        set_clause = ', '.join(f"{key} = ?" for key in updates.keys())
        sql = f"UPDATE {table} SET {set_clause} WHERE id = ?"
        values = list(updates.values()) + [primary_key]
        self.cursor.execute(sql, values)
        if self.cursor.rowcount == 0:
            self.connection.rollback()
            raise RecordNotFoundError(f"Record with primary key {primary_key} not found in table '{table}'.")
        self.connection.commit()
    except sqlite3.Error as e:
        self.connection.rollback()
        raise StorageError(f"Failed to update record: {e}")
StorageError

Bases: Exception

Exception raised when there is a storage-related error in the database.

Source code in src/aeiva/storage/sqlite/sqlite_database.py
class StorageError(Exception):
    """Exception raised when there is a storage-related error in the database."""
    pass

test

RecordNotFoundError

Bases: Exception

Exception raised when a record is not found in the database.

Source code in src/aeiva/storage/sqlite/test.py
class RecordNotFoundError(Exception):
    """Exception raised when a record is not found in the database."""
    pass

test

main()

Main function to run tests for Milvus, Neo4j, and SQLite databases.

Source code in src/aeiva/storage/test.py
def main():
    """
    Main function to run tests for Milvus, Neo4j, and SQLite databases.
    """
    test_milvus()
    test_neo4j()
    test_sqlite()

test_milvus()

Test the DatabaseFactory and DatabaseConfigFactory with Milvus database.

Source code in src/aeiva/storage/test.py
def test_milvus():
    """
    Test the DatabaseFactory and DatabaseConfigFactory with Milvus database.
    """
    print("\n--- Testing Milvus Database ---")
    # Create configuration for Milvus
    milvus_config = DatabaseConfigFactory.create(
        'milvus',
        # uri='tcp://localhost:19530',
        uri='storage/milvus_demo.db',
        collection_name='test_collection',
        embedding_model_dims=128,
        metric_type='COSINE',
    )

    # Create Milvus database instance
    milvus_db = DatabaseFactory.create('milvus', milvus_config)

    try:
        # Prepare sample data
        vector_dimension = milvus_config.embedding_model_dims
        vectors = [
            [float(i) for i in range(vector_dimension)],  # Sample vector 1
            [float(i + 1) for i in range(vector_dimension)],  # Sample vector 2
        ]
        payloads = [
            {'name': 'Vector 1', 'description': 'First test vector.'},
            {'name': 'Vector 2', 'description': 'Second test vector.'},
        ]
        ids = [str(uuid.uuid4()), str(uuid.uuid4())]  # Generate unique IDs

        # Insert vectors into the collection
        milvus_db.insert_vectors(
            collection_name=milvus_config.collection_name,
            vectors=vectors,
            payloads=payloads,
            ids=ids
        )
        logging.info(f"Inserted vectors with IDs: {ids}")

        # Search for similar vectors
        query_vector = [float(i + 0.5) for i in range(vector_dimension)]  # Query vector
        search_results = milvus_db.search_vectors(
            collection_name=milvus_config.collection_name,
            query_vector=query_vector,
            top_k=2
        )
        print(f"Milvus Search results:\n{search_results}")

    except Exception as e:
        logging.error(f"An error occurred while testing Milvus: {e}")
    finally:
        # Close the connection
        del milvus_db

test_neo4j()

Test the DatabaseFactory and DatabaseConfigFactory with Neo4j database.

Source code in src/aeiva/storage/test.py
def test_neo4j():
    """
    Test the DatabaseFactory and DatabaseConfigFactory with Neo4j database.
    """
    print("\n--- Testing Neo4j Database ---")
    # Create configuration for Neo4j
    neo4j_config = DatabaseConfigFactory.create(
        'neo4j',
        uri='bolt://localhost:7687',
        user='neo4j',
        password='cf57bwP9pcdcEK3',  # Replace with your actual password
        database='neo4j',
        encrypted=False,
    )

    # Create Neo4j database instance
    neo4j_db = DatabaseFactory.create('neo4j', neo4j_config)

    try:
        # Add a node
        node_id = 'node1'
        neo4j_db.add_node(
            node_id=node_id,
            properties={'name': 'Alice', 'age': 30},
            labels=['Person']
        )
        logging.info(f"Added node with ID: {node_id}")

        # Retrieve the node
        node_data = neo4j_db.get_node(node_id)
        print(f"Neo4j Node data: {node_data}")

        # Add another node and create a relationship
        node_id2 = 'node2'
        neo4j_db.add_node(
            node_id=node_id2,
            properties={'name': 'Bob', 'age': 25},
            labels=['Person']
        )
        neo4j_db.add_edge(
            source_id=node_id,
            target_id=node_id2,
            relationship='KNOWS',
            properties={'since': 2020}
        )
        logging.info(f"Added edge between {node_id} and {node_id2}")

        # Get neighbors
        neighbors = neo4j_db.get_neighbors(node_id, relationship='KNOWS', direction='out')
        print(f"Neo4j Neighbors of {node_id}: {neighbors}")

    except Exception as e:
        logging.error(f"An error occurred while testing Neo4j: {e}")
    finally:
        # Close the connection
        neo4j_db.close()

test_sqlite()

Test the DatabaseFactory and DatabaseConfigFactory with SQLite database.

Source code in src/aeiva/storage/test.py
def test_sqlite():
    """
    Test the DatabaseFactory and DatabaseConfigFactory with SQLite database.
    """
    print("\n--- Testing SQLite Database ---")
    # Create configuration for SQLite
    sqlite_config = DatabaseConfigFactory.create(
        'sqlite',
        database='storage/test_database.db'  # Use a file-based database for persistence
    )

    # Create SQLite database instance
    sqlite_db = DatabaseFactory.create('sqlite', sqlite_config)

    try:
        # Create a sample table
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            age INTEGER,
            email TEXT UNIQUE
        );
        """
        sqlite_db.execute_sql(create_table_sql)
        logging.info("Created table 'users' in SQLite database.")

        # Insert a record
        record = {'name': 'Alice', 'age': 30, 'email': 'alice@example.com'}
        user_id = sqlite_db.insert_record('users', record)
        logging.info(f"Inserted user with ID: {user_id}")

        # Retrieve the record
        retrieved_record = sqlite_db.get_record('users', user_id)
        print(f"SQLite Retrieved record: {retrieved_record}")

        # Update the record
        updates = {'age': 31}
        sqlite_db.update_record('users', user_id, updates)
        logging.info(f"Updated user with ID: {user_id}")

        # Query records
        conditions = {'age': 31}
        users = sqlite_db.query_records('users', conditions)
        print(f"SQLite Users with age 31: {users}")

    except Exception as e:
        logging.error(f"An error occurred while testing SQLite: {e}")
    finally:
        # Close the database connection
        sqlite_db.close()

vector_database

VectorDatabase

Bases: ABC

Abstract base class for vector storage operations.

Source code in src/aeiva/storage/vector_database.py
class VectorDatabase(ABC):
    """
    Abstract base class for vector storage operations.
    """

    @abstractmethod
    def create_client(
        self,
        uri: str,
        user: Optional[str] = None,
        password: Optional[str] = None,
        db_name: Optional[str] = None,
        token: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> None:
        """
        Initializes the client connection to the vector store.

        Args:
            uri (str): The URI of the vector store instance.
            user (Optional[str]): Username for authentication.
            password (Optional[str]): Password for authentication.
            db_name (Optional[str]): Name of the database.
            token (Optional[str]): Access token for authentication.
            timeout (Optional[float]): Timeout duration for operations.
            **kwargs: Additional implementation-specific parameters.

        Raises:
            ConnectionError: If the client fails to connect to the vector store.
        """
        pass

    @abstractmethod
    def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
        """
        Create a new vector collection.

        Args:
            collection_name (str): The name of the collection.
            vector_size (int): The dimensionality of the vectors.
            distance_metric (str): The distance metric to use (e.g., 'euclidean', 'cosine').

        Raises:
            CollectionAlreadyExistsError: If a collection with the given name already exists.
            StorageError: If there is an issue creating the collection.
        """
        pass

    @abstractmethod
    def insert_vectors(self, collection_name: str, vectors: List[List[float]], payloads: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> None:
        """
        Insert vectors into a collection.

        Args:
            collection_name (str): The name of the collection.
            vectors (List[List[float]]): A list of vectors to insert.
            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
            ids (Optional[List[str]]): Optional unique identifiers for each vector.

        Raises:
            CollectionNotFoundError: If the specified collection does not exist.
            StorageError: If there is an issue inserting the vectors.
        """
        pass

    @abstractmethod
    def search_vectors(self, collection_name: str, query_vector: List[float], top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """
        Search for similar vectors in a collection.

        Args:
            collection_name (str): The name of the collection.
            query_vector (List[float]): The vector to search with.
            top_k (int): The number of top results to return.
            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

        Returns:
            List[Dict[str, Any]]: A list of search results, each containing the vector ID, score, and payload.

        Raises:
            CollectionNotFoundError: If the specified collection does not exist.
            StorageError: If there is an issue performing the search.
        """
        pass

    @abstractmethod
    def delete_vector(self, collection_name: str, vector_id: str) -> None:
        """
        Delete a vector from a collection by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to delete.

        Raises:
            CollectionNotFoundError: If the specified collection does not exist.
            VectorNotFoundError: If the vector with the specified ID does not exist.
            StorageError: If there is an issue deleting the vector.
        """
        pass

    @abstractmethod
    def update_vector(self, collection_name: str, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None) -> None:
        """
        Update a vector's data or payload.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector to update.
            vector (Optional[List[float]]): The new vector data.
            payload (Optional[Dict[str, Any]]): The new payload data.

        Raises:
            CollectionNotFoundError: If the specified collection does not exist.
            VectorNotFoundError: If the vector with the specified ID does not exist.
            StorageError: If there is an issue updating the vector.
        """
        pass

    @abstractmethod
    def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
        """
        Retrieve a vector by its ID.

        Args:
            collection_name (str): The name of the collection.
            vector_id (str): The unique identifier of the vector.

        Returns:
            Dict[str, Any]: A dictionary containing the vector data and payload.

        Raises:
            CollectionNotFoundError: If the specified collection does not exist.
            VectorNotFoundError: If the vector with the specified ID does not exist.
            StorageError: If there is an issue retrieving the vector.
        """
        pass

    @abstractmethod
    def list_collections(self) -> List[str]:
        """
        List all available vector collections.

        Returns:
            List[str]: A list of collection names.

        Raises:
            StorageError: If there is an issue retrieving the collection list.
        """
        pass

    @abstractmethod
    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire vector collection.

        Args:
            collection_name (str): The name of the collection to delete.

        Raises:
            CollectionNotFoundError: If the specified collection does not exist.
            StorageError: If there is an issue deleting the collection.
        """
        pass

    @abstractmethod
    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """
        Get information about a collection.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            Dict[str, Any]: Information about the collection, such as vector size and distance metric.

        Raises:
            CollectionNotFoundError: If the specified collection does not exist.
            StorageError: If there is an issue retrieving the collection information.
        """
        pass
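
Every method above is abstract, so a backend must implement all of them before the class can be instantiated. The following in-memory sketch is illustrative only (the class name, storage layout, and cosine scoring are not part of Aeiva) but shows the shape of a conforming subclass:

import math
from typing import Any, Dict, List, Optional

from aeiva.storage.vector_database import VectorDatabase


class InMemoryVectorDatabase(VectorDatabase):
    """Toy backend: collections are dicts, similarity is cosine."""

    def __init__(self) -> None:
        self._collections: Dict[str, Dict[str, Any]] = {}

    def create_client(self, uri: str, user: Optional[str] = None,
                      password: Optional[str] = None, db_name: Optional[str] = None,
                      token: Optional[str] = None, timeout: Optional[float] = None,
                      **kwargs) -> None:
        pass  # nothing to connect to for an in-memory store

    def create_collection(self, collection_name: str, vector_size: int,
                          distance_metric: str) -> None:
        self._collections[collection_name] = {
            'size': vector_size, 'metric': distance_metric, 'items': {}}

    def insert_vectors(self, collection_name, vectors, payloads=None, ids=None) -> None:
        items = self._collections[collection_name]['items']
        payloads = payloads or [{} for _ in vectors]
        ids = ids or [str(i) for i in range(len(vectors))]
        for id_, vec, payload in zip(ids, vectors, payloads):
            items[id_] = {'vector': vec, 'payload': payload}

    def search_vectors(self, collection_name, query_vector, top_k=5, filters=None):
        def cosine(a, b):
            norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
            return sum(x * y for x, y in zip(a, b)) / norm if norm else 0.0
        items = self._collections[collection_name]['items']
        hits = [{'id': k, 'score': cosine(query_vector, v['vector']),
                 'payload': v['payload']} for k, v in items.items()]
        return sorted(hits, key=lambda h: h['score'], reverse=True)[:top_k]

    def delete_vector(self, collection_name, vector_id) -> None:
        del self._collections[collection_name]['items'][vector_id]

    def update_vector(self, collection_name, vector_id, vector=None, payload=None) -> None:
        entry = self._collections[collection_name]['items'][vector_id]
        if vector is not None:
            entry['vector'] = vector
        if payload is not None:
            entry['payload'] = payload

    def get_vector(self, collection_name, vector_id):
        entry = self._collections[collection_name]['items'][vector_id]
        return {'id': vector_id, **entry}

    def list_collections(self) -> List[str]:
        return list(self._collections)

    def delete_collection(self, collection_name: str) -> None:
        self._collections.pop(collection_name, None)

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        info = self._collections[collection_name]
        return {'vector_size': info['size'], 'distance_metric': info['metric']}
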
create_client(uri, user=None, password=None, db_name=None, token=None, timeout=None, **kwargs) abstractmethod

Initializes the client connection to the vector store.

Parameters:

    uri (str): The URI of the vector store instance. Required.
    user (Optional[str]): Username for authentication. Default: None.
    password (Optional[str]): Password for authentication. Default: None.
    db_name (Optional[str]): Name of the database. Default: None.
    token (Optional[str]): Access token for authentication. Default: None.
    timeout (Optional[float]): Timeout duration for operations. Default: None.
    **kwargs: Additional implementation-specific parameters. Default: {}.

Raises:

    ConnectionError: If the client fails to connect to the vector store.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def create_client(
    self,
    uri: str,
    user: Optional[str] = None,
    password: Optional[str] = None,
    db_name: Optional[str] = None,
    token: Optional[str] = None,
    timeout: Optional[float] = None,
    **kwargs
) -> None:
    """
    Initializes the client connection to the vector store.

    Args:
        uri (str): The URI of the vector store instance.
        user (Optional[str]): Username for authentication.
        password (Optional[str]): Password for authentication.
        db_name (Optional[str]): Name of the database.
        token (Optional[str]): Access token for authentication.
        timeout (Optional[float]): Timeout duration for operations.
        **kwargs: Additional implementation-specific parameters.

    Raises:
        ConnectionError: If the client fails to connect to the vector store.
    """
    pass
create_collection(collection_name, vector_size, distance_metric) abstractmethod

Create a new vector collection.

Parameters:

    collection_name (str): The name of the collection. Required.
    vector_size (int): The dimensionality of the vectors. Required.
    distance_metric (str): The distance metric to use (e.g., 'euclidean', 'cosine'). Required.

Raises:

    CollectionAlreadyExistsError: If a collection with the given name already exists.
    StorageError: If there is an issue creating the collection.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:
    """
    Create a new vector collection.

    Args:
        collection_name (str): The name of the collection.
        vector_size (int): The dimensionality of the vectors.
        distance_metric (str): The distance metric to use (e.g., 'euclidean', 'cosine').

    Raises:
        CollectionAlreadyExistsError: If a collection with the given name already exists.
        StorageError: If there is an issue creating the collection.
    """
    pass
delete_collection(collection_name) abstractmethod

Delete an entire vector collection.

Parameters:

    collection_name (str): The name of the collection to delete. Required.

Raises:

    CollectionNotFoundError: If the specified collection does not exist.
    StorageError: If there is an issue deleting the collection.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def delete_collection(self, collection_name: str) -> None:
    """
    Delete an entire vector collection.

    Args:
        collection_name (str): The name of the collection to delete.

    Raises:
        CollectionNotFoundError: If the specified collection does not exist.
        StorageError: If there is an issue deleting the collection.
    """
    pass
delete_vector(collection_name, vector_id) abstractmethod

Delete a vector from a collection by its ID.

Parameters:

    collection_name (str): The name of the collection. Required.
    vector_id (str): The unique identifier of the vector to delete. Required.

Raises:

    CollectionNotFoundError: If the specified collection does not exist.
    VectorNotFoundError: If the vector with the specified ID does not exist.
    StorageError: If there is an issue deleting the vector.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def delete_vector(self, collection_name: str, vector_id: str) -> None:
    """
    Delete a vector from a collection by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to delete.

    Raises:
        CollectionNotFoundError: If the specified collection does not exist.
        VectorNotFoundError: If the vector with the specified ID does not exist.
        StorageError: If there is an issue deleting the vector.
    """
    pass
get_collection_info(collection_name) abstractmethod

Get information about a collection.

Parameters:

    collection_name (str): The name of the collection. Required.

Returns:

    Dict[str, Any]: Information about the collection, such as vector size and distance metric.

Raises:

    CollectionNotFoundError: If the specified collection does not exist.
    StorageError: If there is an issue retrieving the collection information.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """
    Get information about a collection.

    Args:
        collection_name (str): The name of the collection.

    Returns:
        Dict[str, Any]: Information about the collection, such as vector size and distance metric.

    Raises:
        CollectionNotFoundError: If the specified collection does not exist.
        StorageError: If there is an issue retrieving the collection information.
    """
    pass
get_vector(collection_name, vector_id) abstractmethod

Retrieve a vector by its ID.

Parameters:

    collection_name (str): The name of the collection. Required.
    vector_id (str): The unique identifier of the vector. Required.

Returns:

    Dict[str, Any]: A dictionary containing the vector data and payload.

Raises:

    CollectionNotFoundError: If the specified collection does not exist.
    VectorNotFoundError: If the vector with the specified ID does not exist.
    StorageError: If there is an issue retrieving the vector.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
    """
    Retrieve a vector by its ID.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector.

    Returns:
        Dict[str, Any]: A dictionary containing the vector data and payload.

    Raises:
        CollectionNotFoundError: If the specified collection does not exist.
        VectorNotFoundError: If the vector with the specified ID does not exist.
        StorageError: If there is an issue retrieving the vector.
    """
    pass
insert_vectors(collection_name, vectors, payloads=None, ids=None) abstractmethod

Insert vectors into a collection.

Parameters:

    collection_name (str): The name of the collection. Required.
    vectors (List[List[float]]): A list of vectors to insert. Required.
    payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector. Default: None.
    ids (Optional[List[str]]): Optional unique identifiers for each vector. Default: None.

Raises:

    CollectionNotFoundError: If the specified collection does not exist.
    StorageError: If there is an issue inserting the vectors.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def insert_vectors(self, collection_name: str, vectors: List[List[float]], payloads: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> None:
    """
    Insert vectors into a collection.

    Args:
        collection_name (str): The name of the collection.
        vectors (List[List[float]]): A list of vectors to insert.
        payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
        ids (Optional[List[str]]): Optional unique identifiers for each vector.

    Raises:
        CollectionNotFoundError: If the specified collection does not exist.
        StorageError: If there is an issue inserting the vectors.
    """
    pass
list_collections() abstractmethod

List all available vector collections.

Returns:

    List[str]: A list of collection names.

Raises:

    StorageError: If there is an issue retrieving the collection list.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def list_collections(self) -> List[str]:
    """
    List all available vector collections.

    Returns:
        List[str]: A list of collection names.

    Raises:
        StorageError: If there is an issue retrieving the collection list.
    """
    pass
search_vectors(collection_name, query_vector, top_k=5, filters=None) abstractmethod

Search for similar vectors in a collection.

Parameters:

    collection_name (str): The name of the collection. Required.
    query_vector (List[float]): The vector to search with. Required.
    top_k (int): The number of top results to return. Default: 5.
    filters (Optional[Dict[str, Any]]): Optional filters to apply to the search. Default: None.

Returns:

    List[Dict[str, Any]]: A list of search results, each containing the vector ID, score, and payload.

Raises:

    CollectionNotFoundError: If the specified collection does not exist.
    StorageError: If there is an issue performing the search.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def search_vectors(self, collection_name: str, query_vector: List[float], top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """
    Search for similar vectors in a collection.

    Args:
        collection_name (str): The name of the collection.
        query_vector (List[float]): The vector to search with.
        top_k (int): The number of top results to return.
        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

    Returns:
        List[Dict[str, Any]]: A list of search results, each containing the vector ID, score, and payload.

    Raises:
        CollectionNotFoundError: If the specified collection does not exist.
        StorageError: If there is an issue performing the search.
    """
    pass
update_vector(collection_name, vector_id, vector=None, payload=None) abstractmethod

Update a vector's data or payload.

Parameters:

    collection_name (str): The name of the collection. Required.
    vector_id (str): The unique identifier of the vector to update. Required.
    vector (Optional[List[float]]): The new vector data. Default: None.
    payload (Optional[Dict[str, Any]]): The new payload data. Default: None.

Raises:

    CollectionNotFoundError: If the specified collection does not exist.
    VectorNotFoundError: If the vector with the specified ID does not exist.
    StorageError: If there is an issue updating the vector.

Source code in src/aeiva/storage/vector_database.py
@abstractmethod
def update_vector(self, collection_name: str, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None) -> None:
    """
    Update a vector's data or payload.

    Args:
        collection_name (str): The name of the collection.
        vector_id (str): The unique identifier of the vector to update.
        vector (Optional[List[float]]): The new vector data.
        payload (Optional[Dict[str, Any]]): The new payload data.

    Raises:
        CollectionNotFoundError: If the specified collection does not exist.
        VectorNotFoundError: If the vector with the specified ID does not exist.
        StorageError: If there is an issue updating the vector.
    """
    pass

weaviate

weaviate_config

WeaviateConfig dataclass

Bases: BaseConfig

Configuration for Weaviate vector database.

Source code in src/aeiva/storage/weaviate/weaviate_config.py
@dataclass
class WeaviateConfig(BaseConfig):
    """
    Configuration for Weaviate vector database.
    """

    url: str = field(
        default='http://localhost:8080',
        metadata={"help": "URL of the Weaviate instance (e.g., 'http://localhost:8080')."}
    )
    api_key: Optional[str] = field(
        default=None,
        metadata={"help": "API key for Weaviate authentication (if required)."}
    )
    auth_client_secret: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Authentication client secret for Weaviate (if using OIDC)."}
    )
    timeout_config: Optional[Tuple[float, float]] = field(
        default=(2, 20),
        metadata={"help": "Timeout configuration for requests (connect timeout, read timeout)."}
    )
    additional_headers: Optional[Dict[str, str]] = field(
        default=None,
        metadata={"help": "Additional headers to include in requests to Weaviate."}
    )
    embedding_model: Optional[str] = field(
        default=None,
        metadata={"help": "Name of the embedding model used (if required)."}
    )
    index_name: str = field(
        default='MyIndex',
        metadata={"help": "Name of the Weaviate index (class)."}
    )
    vector_dim: int = field(
        default=512,
        metadata={"help": "Dimensionality of the vectors stored in Weaviate."}
    )
    distance_metric: str = field(
        default='cosine',
        metadata={"help": "Distance metric to use (e.g., 'cosine', 'l2-squared', 'dot')."}
    )

    def __post_init__(self):
        super().__post_init__()
        if not self.url:
            raise ValueError("The 'url' parameter is required for Weaviate configuration.")

weaviate_database

WeaviateDatabase

Bases: VectorDatabase

Concrete implementation of VectorDatabase using Weaviate.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
class WeaviateDatabase(VectorDatabase):
    """
    Concrete implementation of VectorDatabase using Weaviate.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the Weaviate vector store.

        Args:
            config (Dict[str, Any]): Configuration dictionary.
        """
        self.config = config
        self.url = config.get('url', 'http://localhost:8080')
        self.api_key = config.get('api_key')
        self.auth_client_secret = config.get('auth_client_secret')
        self.timeout_config = config.get('timeout_config', (2, 20))
        self.additional_headers = config.get('additional_headers')
        self.embedding_model = config.get('embedding_model')
        self.index_name = config.get('index_name', 'MyIndex')
        self.vector_dim = config.get('vector_dim', 512)
        self.distance_metric = config.get('distance_metric', 'cosine')

        self.client = self.create_client()
        self.create_index(
            index_name=self.index_name,
            vector_dim=self.vector_dim,
            distance_metric=self.distance_metric
        )

    def create_client(self) -> Client:
        """
        Initializes the client connection to the Weaviate vector store.

        Returns:
            Client: The Weaviate client instance.

        Raises:
            ConnectionError: If the client fails to connect to the Weaviate instance.
        """
        try:
            if self.api_key:
                auth_config = AuthApiKey(api_key=self.api_key)
            elif self.auth_client_secret:
                auth_config = AuthClientPassword(**self.auth_client_secret)
            else:
                auth_config = None

            client = weaviate.Client(
                url=self.url,
                auth_client_secret=auth_config,
                timeout_config=self.timeout_config,
                additional_headers=self.additional_headers
            )

            if not client.is_ready():
                raise ConnectionError(f"Weaviate at {self.url} is not ready.")

            logger.info(f"Connected to Weaviate at {self.url}.")
            return client
        except Exception as e:
            logger.error(f"Failed to connect to Weaviate: {e}")
            raise ConnectionError(f"Failed to connect to Weaviate: {e}")

    def create_index(self, index_name: str, vector_dim: int, distance_metric: str) -> None:
        """
        Create a new index (class) in Weaviate.

        Args:
            index_name (str): The name of the index.
            vector_dim (int): The dimensionality of the vectors.
            distance_metric (str): The distance metric to use.

        Raises:
            WeaviateException: If there is an issue creating the index.
        """
        try:
            if self.client.schema.contains(index_name):
                logger.info(f"Index {index_name} already exists. Skipping creation.")
                return

            class_obj = {
                "class": index_name,
                "vectorizer": "none",
                "vectorIndexType": "hnsw",
                "vectorIndexConfig": {
                    "distance": distance_metric
                },
                "properties": [
                    {
                        "name": "id",
                        "dataType": ["string"],
                        "description": "Unique identifier",
                    },
                    {
                        "name": "payload",
                        "dataType": ["blob"],
                        "description": "Payload data",
                    },
                ]
            }

            self.client.schema.create_class(class_obj)
            logger.info(f"Index {index_name} created successfully.")
        except WeaviateException as e:
            logger.error(f"Failed to create index: {e}")
            raise

    def insert_vectors(
        self,
        collection_name: str,
        vectors: List[List[float]],
        payloads: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into the collection.

        Args:
            collection_name (str): The name of the collection (index).
            vectors (List[List[float]]): A list of vectors to insert.
            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
            ids (Optional[List[str]]): Optional unique identifiers for each vector.

        Raises:
            ValueError: If input data is invalid.
            WeaviateException: If there is an issue inserting vectors.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")

        if ids is None:
            raise ValueError("Weaviate requires IDs to be provided for each vector.")

        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]

        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

        try:
            with self.client.batch(batch_size=100) as batch:
                for id_, vector, payload in zip(ids, vectors, payloads):
                    data_object = {
                        "id": id_,
                        "payload": payload
                    }
                    batch.add_data_object(
                        data_object=data_object,
                        class_name=collection_name,
                        vector=vector
                    )
            logger.info(f"Inserted {len(vectors)} vectors into index {collection_name}.")
        except WeaviateException as e:
            logger.error(f"Failed to insert vectors: {e}")
            raise

    def search_vectors(
        self,
        collection_name: str,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for similar vectors in the collection.

        Args:
            collection_name (str): The name of the collection (index).
            query_vector (List[float]): The vector to search with.
            top_k (int): The number of top results to return.
            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

        Returns:
            List[Dict[str, Any]]: A list of search results.

        Raises:
            ValueError: If collection name does not match.
            WeaviateException: If there is an issue performing the search.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")

        try:
            near_vector = {
                "vector": query_vector,
            }

            where_filter = self._build_filters(filters)

            result = self.client.query.get(
                class_name=collection_name,
                properties=["id", "payload"]
            ).with_near_vector(near_vector).with_where(where_filter).with_limit(top_k).do()

            output = []
            for item in result["data"]["Get"][collection_name]:
                result_item = {
                    "id": item["id"],
                    "score": item["_additional"]["certainty"],  # or distance
                    "payload": item["payload"]
                }
                output.append(result_item)
            return output
        except WeaviateException as e:
            logger.error(f"Failed to search vectors: {e}")
            raise

    def _build_filters(self, filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """
        Build a Weaviate where filter from a dictionary.

        Args:
            filters (Optional[Dict[str, Any]]): Filters to apply.

        Returns:
            Optional[Dict[str, Any]]: A Weaviate where filter.
        """
        if not filters:
            return None

        conditions = []
        for key, value in filters.items():
            condition = {
                "path": [key],
                "operator": "Equal",
                "valueString": value if isinstance(value, str) else None,
                "valueInt": value if isinstance(value, int) else None,
                "valueBoolean": value if isinstance(value, bool) else None,
                "valueNumber": value if isinstance(value, float) else None,
            }
            conditions.append(condition)

        where_filter = {
            "operator": "And",
            "operands": conditions
        }

        return where_filter

    def delete_vector(self, collection_name: str, vector_id: str) -> None:
        """
        Delete a vector from the collection by its ID.

        Args:
            collection_name (str): The name of the collection (index).
            vector_id (str): The unique identifier of the vector to delete.

        Raises:
            ValueError: If collection name does not match.
            WeaviateException: If there is an issue deleting the vector.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")

        try:
            self.client.data_object.delete(
                uuid=vector_id,
                class_name=collection_name
            )
            logger.info(f"Deleted vector with ID {vector_id} from index {collection_name}.")
        except WeaviateException as e:
            logger.error(f"Failed to delete vector: {e}")
            raise

    def update_vector(
        self,
        collection_name: str,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Update a vector's data or payload.

        Args:
            collection_name (str): The name of the collection (index).
            vector_id (str): The unique identifier of the vector to update.
            vector (Optional[List[float]]): The new vector data.
            payload (Optional[Dict[str, Any]]): The new payload data.

        Raises:
            ValueError: If collection name does not match.
            WeaviateException: If there is an issue updating the vector.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")

        try:
            data_object = {}
            if payload is not None:
                data_object["payload"] = payload

            self.client.data_object.update(
                data_object=data_object,
                class_name=collection_name,
                uuid=vector_id,
                vector=vector
            )
            logger.info(f"Updated vector with ID {vector_id} in index {collection_name}.")
        except WeaviateException as e:
            logger.error(f"Failed to update vector: {e}")
            raise

    def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
        """
        Retrieve a vector by its ID.

        Args:
            collection_name (str): The name of the collection (index).
            vector_id (str): The unique identifier of the vector.

        Returns:
            Dict[str, Any]: A dictionary containing the vector data and payload.

        Raises:
            ValueError: If collection name does not match.
            KeyError: If the vector is not found.
            WeaviateException: If there is an issue retrieving the vector.
        """
        if collection_name != self.index_name:
            raise ValueError("Collection name does not match initialized index name.")

        try:
            result = self.client.data_object.get_by_id(
                uuid=vector_id,
                class_name=collection_name,
                additional_properties=["vector"]
            )
            if result is None:
                raise KeyError(f"Vector with ID {vector_id} not found in index {collection_name}.")

            vector_data = {
                "id": result["id"],
                "vector": result["vector"],
                "payload": result["payload"]
            }
            return vector_data
        except WeaviateException as e:
            logger.error(f"Failed to retrieve vector: {e}")
            raise

    def list_collections(self) -> List[str]:
        """
        List all available indexes (classes).

        Returns:
            List[str]: A list of index names.
        """
        try:
            schema = self.client.schema.get()
            return [clazz["class"] for clazz in schema["classes"]]
        except WeaviateException as e:
            logger.error(f"Failed to list collections: {e}")
            raise

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire index (class).

        Args:
            collection_name (str): The name of the collection (index) to delete.

        Raises:
            WeaviateException: If there is an issue deleting the collection.
        """
        try:
            self.client.schema.delete_class(collection_name)
            logger.info(f"Deleted index {collection_name}.")
        except WeaviateException as e:
            logger.error(f"Failed to delete collection: {e}")
            raise

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """
        Get information about a collection (index).

        Args:
            collection_name (str): The name of the collection (index).

        Returns:
            Dict[str, Any]: Information about the collection.

        Raises:
            WeaviateException: If there is an issue retrieving the collection info.
        """
        try:
            class_schema = self.client.schema.get(class_name=collection_name)
            return class_schema
        except WeaviateException as e:
            logger.error(f"Failed to get collection info: {e}")
            raise

    def __del__(self):
        """Clean up resources."""
        if hasattr(self, 'client'):
            self.client.close()
            logger.info("Closed connection to Weaviate.")
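
An end-to-end sketch of the class (not a guaranteed recipe): it assumes a reachable Weaviate instance at the given URL, and the IDs are UUID strings because they are handed to Weaviate as object UUIDs.

import uuid

from aeiva.storage.weaviate.weaviate_database import WeaviateDatabase

# Plain config dict matching the keys read in __init__ above.
db = WeaviateDatabase({
    'url': 'http://localhost:8080',  # assumes a local Weaviate instance
    'index_name': 'Demo',
    'vector_dim': 4,
})

ids = [str(uuid.uuid4()), str(uuid.uuid4())]
db.insert_vectors(
    collection_name='Demo',  # must equal the configured index_name
    vectors=[[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
    payloads=[{'doc': 'a'}, {'doc': 'b'}],
    ids=ids,
)
hits = db.search_vectors('Demo', query_vector=[0.1, 0.2, 0.3, 0.4], top_k=1)
print(hits)
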
__del__()

Clean up resources.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def __del__(self):
    """Clean up resources."""
    if hasattr(self, 'client'):
        self.client.close()
        logger.info("Closed connection to Weaviate.")
__init__(config)

Initialize the Weaviate vector store.

Parameters:

    config (Dict[str, Any]): Configuration dictionary. Required.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def __init__(self, config: Dict[str, Any]) -> None:
    """
    Initialize the Weaviate vector store.

    Args:
        config (Dict[str, Any]): Configuration dictionary.
    """
    self.config = config
    self.url = config.get('url', 'http://localhost:8080')
    self.api_key = config.get('api_key')
    self.auth_client_secret = config.get('auth_client_secret')
    self.timeout_config = config.get('timeout_config', (2, 20))
    self.additional_headers = config.get('additional_headers')
    self.embedding_model = config.get('embedding_model')
    self.index_name = config.get('index_name', 'MyIndex')
    self.vector_dim = config.get('vector_dim', 512)
    self.distance_metric = config.get('distance_metric', 'cosine')

    self.client = self.create_client()
    self.create_index(
        index_name=self.index_name,
        vector_dim=self.vector_dim,
        distance_metric=self.distance_metric
    )
create_client()

Initializes the client connection to the Weaviate vector store.

Returns:

    Client: The Weaviate client instance.

Raises:

    ConnectionError: If the client fails to connect to the Weaviate instance.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def create_client(self) -> Client:
    """
    Initializes the client connection to the Weaviate vector store.

    Returns:
        Client: The Weaviate client instance.

    Raises:
        ConnectionError: If the client fails to connect to the Weaviate instance.
    """
    try:
        if self.api_key:
            auth_config = AuthApiKey(api_key=self.api_key)
        elif self.auth_client_secret:
            auth_config = AuthClientPassword(**self.auth_client_secret)
        else:
            auth_config = None

        client = weaviate.Client(
            url=self.url,
            auth_client_secret=auth_config,
            timeout_config=self.timeout_config,
            additional_headers=self.additional_headers
        )

        if not client.is_ready():
            raise ConnectionError(f"Weaviate at {self.url} is not ready.")

        logger.info(f"Connected to Weaviate at {self.url}.")
        return client
    except Exception as e:
        logger.error(f"Failed to connect to Weaviate: {e}")
        raise ConnectionError(f"Failed to connect to Weaviate: {e}")
create_index(index_name, vector_dim, distance_metric)

Create a new index (class) in Weaviate.

Parameters:

    index_name (str): The name of the index. Required.
    vector_dim (int): The dimensionality of the vectors. Required.
    distance_metric (str): The distance metric to use. Required.

Raises:

    WeaviateException: If there is an issue creating the index.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def create_index(self, index_name: str, vector_dim: int, distance_metric: str) -> None:
    """
    Create a new index (class) in Weaviate.

    Args:
        index_name (str): The name of the index.
        vector_dim (int): The dimensionality of the vectors.
        distance_metric (str): The distance metric to use.

    Raises:
        WeaviateException: If there is an issue creating the index.
    """
    try:
        if self.client.schema.contains(index_name):
            logger.info(f"Index {index_name} already exists. Skipping creation.")
            return

        class_obj = {
            "class": index_name,
            "vectorizer": "none",
            "vectorIndexType": "hnsw",
            "vectorIndexConfig": {
                "distance": distance_metric
            },
            "properties": [
                {
                    "name": "id",
                    "dataType": ["string"],
                    "description": "Unique identifier",
                },
                {
                    "name": "payload",
                    "dataType": ["blob"],
                    "description": "Payload data",
                },
            ]
        }

        self.client.schema.create_class(class_obj)
        logger.info(f"Index {index_name} created successfully.")
    except WeaviateException as e:
        logger.error(f"Failed to create index: {e}")
        raise
delete_collection(collection_name)

Delete an entire index (class).

Parameters:

    collection_name (str): The name of the collection (index) to delete. Required.

Raises:

    WeaviateException: If there is an issue deleting the collection.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def delete_collection(self, collection_name: str) -> None:
    """
    Delete an entire index (class).

    Args:
        collection_name (str): The name of the collection (index) to delete.

    Raises:
        WeaviateException: If there is an issue deleting the collection.
    """
    try:
        self.client.schema.delete_class(collection_name)
        logger.info(f"Deleted index {collection_name}.")
    except WeaviateException as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
delete_vector(collection_name, vector_id)

Delete a vector from the collection by its ID.

Parameters:

    collection_name (str): The name of the collection (index). Required.
    vector_id (str): The unique identifier of the vector to delete. Required.

Raises:

    ValueError: If the collection name does not match the initialized index name.
    WeaviateException: If there is an issue deleting the vector.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:
    """
    Delete a vector from the collection by its ID.

    Args:
        collection_name (str): The name of the collection (index).
        vector_id (str): The unique identifier of the vector to delete.

    Raises:
        ValueError: If collection name does not match.
        WeaviateException: If there is an issue deleting the vector.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")

    try:
        self.client.data_object.delete(
            uuid=vector_id,
            class_name=collection_name
        )
        logger.info(f"Deleted vector with ID {vector_id} from index {collection_name}.")
    except WeaviateException as e:
        logger.error(f"Failed to delete vector: {e}")
        raise
get_collection_info(collection_name)

Get information about a collection (index).

Parameters:

    collection_name (str): The name of the collection (index). Required.

Returns:

    Dict[str, Any]: Information about the collection.

Raises:

    WeaviateException: If there is an issue retrieving the collection info.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """
    Get information about a collection (index).

    Args:
        collection_name (str): The name of the collection (index).

    Returns:
        Dict[str, Any]: Information about the collection.

    Raises:
        WeaviateException: If there is an issue retrieving the collection info.
    """
    try:
        class_schema = self.client.schema.get(class_name=collection_name)
        return class_schema
    except WeaviateException as e:
        logger.error(f"Failed to get collection info: {e}")
        raise
get_vector(collection_name, vector_id)

Retrieve a vector by its ID.

Parameters:

    collection_name (str): The name of the collection (index). Required.
    vector_id (str): The unique identifier of the vector. Required.

Returns:

    Dict[str, Any]: A dictionary containing the vector data and payload.

Raises:

    ValueError: If the collection name does not match the initialized index name.
    KeyError: If the vector is not found.
    WeaviateException: If there is an issue retrieving the vector.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:
    """
    Retrieve a vector by its ID.

    Args:
        collection_name (str): The name of the collection (index).
        vector_id (str): The unique identifier of the vector.

    Returns:
        Dict[str, Any]: A dictionary containing the vector data and payload.

    Raises:
        ValueError: If collection name does not match.
        KeyError: If the vector is not found.
        WeaviateException: If there is an issue retrieving the vector.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")

    try:
        result = self.client.data_object.get_by_id(
            uuid=vector_id,
            class_name=collection_name,
            additional_properties=["vector"]
        )
        if result is None:
            raise KeyError(f"Vector with ID {vector_id} not found in index {collection_name}.")

        vector_data = {
            "id": result["id"],
            "vector": result["vector"],
            "payload": result["payload"]
        }
        return vector_data
    except WeaviateException as e:
        logger.error(f"Failed to retrieve vector: {e}")
        raise
insert_vectors(collection_name, vectors, payloads=None, ids=None)

Insert vectors into the collection.

Parameters:

    collection_name (str): The name of the collection (index). Required.
    vectors (List[List[float]]): A list of vectors to insert. Required.
    payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector. Default: None.
    ids (Optional[List[str]]): Optional unique identifiers for each vector. Default: None.

Raises:

    ValueError: If input data is invalid.
    WeaviateException: If there is an issue inserting vectors.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def insert_vectors(
    self,
    collection_name: str,
    vectors: List[List[float]],
    payloads: Optional[List[Dict[str, Any]]] = None,
    ids: Optional[List[str]] = None
) -> None:
    """
    Insert vectors into the collection.

    Args:
        collection_name (str): The name of the collection (index).
        vectors (List[List[float]]): A list of vectors to insert.
        payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.
        ids (Optional[List[str]]): Optional unique identifiers for each vector.

    Raises:
        ValueError: If input data is invalid.
        WeaviateException: If there is an issue inserting vectors.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")

    if ids is None:
        raise ValueError("Weaviate requires IDs to be provided for each vector.")

    if payloads is None:
        payloads = [{} for _ in range(len(vectors))]

    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError("Lengths of ids, vectors, and payloads must be equal.")

    try:
        with self.client.batch(batch_size=100) as batch:
            for id_, vector, payload in zip(ids, vectors, payloads):
                data_object = {
                    "id": id_,
                    "payload": payload
                }
                batch.add_data_object(
                    data_object=data_object,
                    class_name=collection_name,
                    vector=vector
                )
        logger.info(f"Inserted {len(vectors)} vectors into index {collection_name}.")
    except WeaviateException as e:
        logger.error(f"Failed to insert vectors: {e}")
        raise
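
A hedged usage sketch (assumes the same `db` instance bound to the "Documents" index; note that explicit IDs are mandatory here):

import uuid

# Hypothetical usage sketch; `db` is a WeaviateDatabase bound to "Documents".
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
payloads = [{"title": "intro"}, {"title": "methods"}]
ids = [str(uuid.uuid4()) for _ in vectors]  # Weaviate requires explicit IDs

db.insert_vectors(
    collection_name="Documents",
    vectors=vectors,
    payloads=payloads,
    ids=ids,
)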
list_collections()

List all available indexes (classes).

Returns:

- List[str]: A list of index names.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def list_collections(self) -> List[str]:
    """
    List all available indexes (classes).

    Returns:
        List[str]: A list of index names.
    """
    try:
        schema = self.client.schema.get()
        return [clazz["class"] for clazz in schema["classes"]]
    except WeaviateException as e:
        logger.error(f"Failed to list collections: {e}")
        raise
search_vectors(collection_name, query_vector, top_k=5, filters=None)

Search for similar vectors in the collection.

Parameters:

- collection_name (str, required): The name of the collection (index).
- query_vector (List[float], required): The vector to search with.
- top_k (int, default 5): The number of top results to return.
- filters (Optional[Dict[str, Any]], default None): Optional filters to apply to the search.

Returns:

- List[Dict[str, Any]]: A list of search results.

Raises:

- ValueError: If collection name does not match.
- WeaviateException: If there is an issue performing the search.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def search_vectors(
    self,
    collection_name: str,
    query_vector: List[float],
    top_k: int = 5,
    filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    Search for similar vectors in the collection.

    Args:
        collection_name (str): The name of the collection (index).
        query_vector (List[float]): The vector to search with.
        top_k (int): The number of top results to return.
        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.

    Returns:
        List[Dict[str, Any]]: A list of search results.

    Raises:
        ValueError: If collection name does not match.
        WeaviateException: If there is an issue performing the search.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")

    try:
        near_vector = {
            "vector": query_vector,
        }

        where_filter = self._build_filters(filters)

        # Request _additional.certainty explicitly so the score is present in
        # the response, and only add the where clause when filters were given.
        query = self.client.query.get(
            class_name=collection_name,
            properties=["id", "payload"]
        ).with_near_vector(near_vector).with_additional(["certainty"]).with_limit(top_k)
        if where_filter:
            query = query.with_where(where_filter)
        result = query.do()

        output = []
        for item in result["data"]["Get"][collection_name]:
            result_item = {
                "id": item["id"],
                "score": item["_additional"]["certainty"],  # or distance
                "payload": item["payload"]
            }
            output.append(result_item)
        return output
    except WeaviateException as e:
        logger.error(f"Failed to search vectors: {e}")
        raise
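
A hedged usage sketch (the filter value is illustrative; its exact shape depends on _build_filters, which is not shown in this reference):

# Hypothetical usage sketch; `db` is a WeaviateDatabase bound to "Documents".
hits = db.search_vectors(
    collection_name="Documents",
    query_vector=[0.1, 0.2, 0.3],
    top_k=3,
    filters={"title": "intro"},  # illustrative; shape depends on _build_filters
)
for hit in hits:
    print(hit["id"], hit["score"], hit["payload"])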
update_vector(collection_name, vector_id, vector=None, payload=None)

Update a vector's data or payload.

Parameters:

- collection_name (str, required): The name of the collection (index).
- vector_id (str, required): The unique identifier of the vector to update.
- vector (Optional[List[float]], default None): The new vector data.
- payload (Optional[Dict[str, Any]], default None): The new payload data.

Raises:

- ValueError: If collection name does not match.
- WeaviateException: If there is an issue updating the vector.

Source code in src/aeiva/storage/weaviate/weaviate_database.py
def update_vector(
    self,
    collection_name: str,
    vector_id: str,
    vector: Optional[List[float]] = None,
    payload: Optional[Dict[str, Any]] = None
) -> None:
    """
    Update a vector's data or payload.

    Args:
        collection_name (str): The name of the collection (index).
        vector_id (str): The unique identifier of the vector to update.
        vector (Optional[List[float]]): The new vector data.
        payload (Optional[Dict[str, Any]]): The new payload data.

    Raises:
        ValueError: If collection name does not match.
        WeaviateException: If there is an issue updating the vector.
    """
    if collection_name != self.index_name:
        raise ValueError("Collection name does not match initialized index name.")

    try:
        data_object = {}
        if payload is not None:
            data_object["payload"] = payload

        self.client.data_object.update(
            data_object=data_object,
            class_name=collection_name,
            uuid=vector_id,
            vector=vector
        )
        logger.info(f"Updated vector with ID {vector_id} in index {collection_name}.")
    except WeaviateException as e:
        logger.error(f"Failed to update vector: {e}")
        raise
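
A hedged usage sketch (same assumptions as above; pass vector=None to leave the stored embedding unchanged):

# Hypothetical usage sketch; updates only the payload of an existing object.
db.update_vector(
    collection_name="Documents",
    vector_id="3f2a1c2e-0000-0000-0000-000000000000",
    vector=None,                      # keep the stored vector
    payload={"title": "intro (revised)"},
)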

tool

api_server

call_api_action(api_name, action_name, request) async

Endpoint to dynamically call an action within a specified API.

Parameters:

- api_name (str, required): The name of the API.
- action_name (str, required): The name of the action/function to execute.
- request (Request, required): The incoming HTTP request.

Returns:

- dict: The result of the action or an error message.

Source code in src/aeiva/tool/api_server.py
@app.get("/api/{api_name}/{action_name}")
async def call_api_action(api_name: str, action_name: str, request: Request):
    """
    Endpoint to dynamically call an action within a specified API.

    Args:
        api_name (str): The name of the API.
        action_name (str): The name of the action/function to execute.
        request (Request): The incoming HTTP request.

    Returns:
        dict: The result of the action or an error message.
    """
    try:
        logger.info(f"Starting call_api_action for API '{api_name}', Action '{action_name}'")

        # Load the API module
        module = load_api_module(api_name)

        # Retrieve the action function
        try:
            action = getattr(module, action_name)
            logger.info(f"Retrieved action '{action_name}' from API '{api_name}'")
        except AttributeError:
            logger.error(f"Action '{action_name}' not found in API '{api_name}'")
            raise HTTPException(status_code=404, detail=f"Action '{action_name}' not found in API '{api_name}'")

        # Extract parameters based on request method
        params = {}
        if request.method in ["POST", "PUT", "PATCH"]:
            try:
                params = await request.json()
                logger.info(f"Received JSON payload: {params}")
            except json.JSONDecodeError:
                logger.error("Invalid JSON payload")
                raise HTTPException(status_code=400, detail="Invalid JSON payload")
        else:
            # For GET requests, extract query parameters
            params = dict(request.query_params)
            logger.info(f"Received query parameters: {params}")

        # Get the function signature
        sig = signature(action)
        logger.info(f"Function signature for '{action_name}': {sig}")

        # Prepare to collect converted parameters
        converted_params = {}

        for param_name, param in sig.parameters.items():
            if param_name in params:
                value = params[param_name]
                param_type = param.annotation if param.annotation != Parameter.empty else str
                try:
                    if param_type == bool:
                        # Convert to boolean
                        if isinstance(value, bool):
                            converted_value = value
                        elif isinstance(value, str):
                            converted_value = value.lower() in ("true", "1", "yes")
                        else:
                            converted_value = bool(value)
                    elif param_type in [int, float, str]:
                        converted_value = param_type(value)
                    elif param_type == list or param_type == dict:
                        converted_value = json.loads(value)
                    else:
                        # For more complex types, assume Pydantic models or custom parsing
                        converted_value = param_type(value)
                    converted_params[param_name] = converted_value
                    logger.debug(f"Converted parameter '{param_name}': {converted_value} (Type: {param_type})")
                except (ValueError, json.JSONDecodeError, TypeError) as e:
                    logger.error(f"Invalid value for parameter '{param_name}': {value} ({e})")
                    raise HTTPException(
                        status_code=400,
                        detail=f"Invalid value for parameter '{param_name}': {value}. Expected type {param_type.__name__}."
                    )
            else:
                if param.default == Parameter.empty:
                    logger.error(f"Missing required parameter: {param_name}")
                    raise HTTPException(status_code=400, detail=f"Missing required parameter: {param_name}")
                else:
                    # Use default value
                    converted_params[param_name] = param.default
                    logger.debug(f"Using default value for parameter '{param_name}': {param.default}")

        # Determine if the action is asynchronous
        if asyncio.iscoroutinefunction(action):
            logger.info(f"Action '{action_name}' is asynchronous. Awaiting execution.")
            result = await action(**converted_params)
        else:
            logger.info(f"Action '{action_name}' is synchronous. Executing directly.")
            result = action(**converted_params)

        logger.info(f"Action '{action_name}' executed successfully with result: {result}")
        return {"result": result}

    except FileNotFoundError as e:
        logger.error(f"API module not found: {e}")
        raise HTTPException(status_code=404, detail=str(e))
    except HTTPException as he:
        # Re-raise HTTP exceptions to be handled by FastAPI
        raise he
    except Exception as e:
        logger.error(f"Unhandled exception in call_api_action: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error")
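
A hedged client-side sketch (assumes the server is reachable at http://localhost:8000 and that a "calculator" API exposing a "calculator" action exists; both names are illustrative). For GET requests, parameters travel as query strings and are converted server-side according to the action's signature:

import requests

# Hypothetical usage sketch; API and action names are illustrative.
resp = requests.get(
    "http://localhost:8000/api/calculator/calculator",
    params={"expression": "1 + 2"},
)
print(resp.json())  # {"result": ...} on success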

load_api_module(api_name)

Dynamically load the API module for the given api_name.

Parameters:

- api_name (str, required): The name of the API.

Returns:

- module: The loaded API module.

Raises:

- FileNotFoundError: If the API module does not exist.
- ImportError: If the module cannot be imported.

Source code in src/aeiva/tool/api_server.py
def load_api_module(api_name: str):
    """
    Dynamically load the API module for the given api_name.

    Args:
        api_name (str): The name of the API.

    Returns:
        module: The loaded API module.

    Raises:
        FileNotFoundError: If the API module does not exist.
        ImportError: If the module cannot be imported.
    """
    # Construct the path to the API module
    api_path = BASE_DIR / "api" / api_name / "api.py"

    if not api_path.exists():
        logger.error(f"API module not found at path: {api_path}")
        raise FileNotFoundError(f"API module not found at path: {api_path}")

    module_name = f"aeiva.tool.api.{api_name}.api"
    spec = importlib.util.spec_from_file_location(module_name, str(api_path))
    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
        logger.info(f"Successfully loaded module '{module_name}'")
    except Exception as e:
        logger.error(f"Failed to load module '{module_name}': {e}")
        raise ImportError(f"Failed to load module '{module_name}': {e}")
    return module
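
A hedged sketch of loading an API module directly (the "calculator" name is illustrative):

# Hypothetical usage sketch; loads src/aeiva/tool/api/calculator/api.py.
module = load_api_module("calculator")
action = getattr(module, "calculator")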

root() async

Root endpoint to confirm the API server is running.

Source code in src/aeiva/tool/api_server.py
@app.get("/")
async def root():
    """
    Root endpoint to confirm the API server is running.
    """
    return {"message": "Welcome to the AI Agent API system!"}

tool

Tool

Source code in src/aeiva/tool/tool.py
class Tool:
    def __init__(self, api_name: str):
        """
        Initialize the tool, determining whether it should run locally or via an external service.
        Args:
            api_name (str): The name of the tool API (matches the function name).
        """
        self.api_name = api_name
        self.schema = self.load_tool_schema(api_name)

    @classmethod
    def load_tool_schema(cls, api_name: str) -> dict:
        """
        Load the tool's schema from the JSON file.
        Args:
            api_name (str): The name of the API or function.
        Returns:
            dict: The loaded schema from the JSON file.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.abspath(os.path.join(current_path, "../../.."))
        path = os.path.join(
            project_root,
            f"src/aeiva/tool/api/{api_name}/{api_name}.json",
        )
        with open(path, "r") as file:
            return json.load(file)

    async def aexecute(self, params: dict) -> Any:
        """
        Execute the tool by calling the corresponding function (whether it's for a local function or encapsulated API call).
        Args:
            params (dict): Parameters to pass to the tool.
        Returns:
            Any: The result of the tool execution.
        """
        function_module = f"aeiva.tool.api.{self.api_name}.api"
        func_module = import_module(function_module)

        # Check if the function is async
        function: Callable = getattr(func_module, self.api_name)
        if asyncio.iscoroutinefunction(function):
            return await function(**params)
        else:
            return function(**params)

    def execute(self, params: dict) -> Any:
        """
        Execute the tool synchronously by calling the corresponding function.

        Args:
            params (dict): Parameters to pass to the tool.

        Returns:
            Any: The result of the tool execution.
        """
        function_module = f"aeiva.tool.api.{self.api_name}.api"
        func_module = import_module(function_module)

        function: Callable = getattr(func_module, self.api_name)
        if asyncio.iscoroutinefunction(function):
            # The function is async: it can only be driven to completion
            # here when no event loop is already running in this thread.
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No event loop running, use asyncio.run
                return asyncio.run(function(**params))
            # A loop is already running: loop.run_until_complete() would
            # raise "This event loop is already running", so async callers
            # should await aexecute() instead.
            raise RuntimeError(
                f"Tool '{self.api_name}' is async; await aexecute() inside a running event loop."
            )
        else:
            # If the function is synchronous, call it directly
            return function(**params)
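
A hedged usage sketch (assumes a "calculator" tool API exists under src/aeiva/tool/api/calculator/ with a matching JSON schema; the parameter name is illustrative):

import asyncio

# Hypothetical usage sketch; tool and parameter names are illustrative.
tool = Tool("calculator")

result = tool.execute({"expression": "1 + 2"})                # synchronous path
result = asyncio.run(tool.aexecute({"expression": "1 + 2"}))  # asynchronous path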
__init__(api_name)

Initialize the tool, determining whether it should run locally or via an external service.

Parameters:

- api_name (str, required): The name of the tool API (matches the function name).

Source code in src/aeiva/tool/tool.py
def __init__(self, api_name: str):
    """
    Initialize the tool, determining whether it should run locally or via an external service.
    Args:
        api_name (str): The name of the tool API (matches the function name).
    """
    self.api_name = api_name
    self.schema = self.load_tool_schema(api_name)
aexecute(params) async

Execute the tool by calling the corresponding function (whether it is a local function or an encapsulated API call).

Parameters:

- params (dict, required): Parameters to pass to the tool.

Returns:

- Any: The result of the tool execution.

Source code in src/aeiva/tool/tool.py
async def aexecute(self, params: dict) -> Any:
    """
    Execute the tool by calling the corresponding function (whether it's for a local function or encapsulated API call).
    Args:
        params (dict): Parameters to pass to the tool.
    Returns:
        Any: The result of the tool execution.
    """
    function_module = f"aeiva.tool.api.{self.api_name}.api"
    func_module = import_module(function_module)

    # Check if the function is async
    function: Callable = getattr(func_module, self.api_name)
    if asyncio.iscoroutinefunction(function):
        return await function(**params)
    else:
        return function(**params)
execute(params)

Execute the tool synchronously by calling the corresponding function.

Parameters:

- params (dict, required): Parameters to pass to the tool.

Returns:

- Any: The result of the tool execution.

Source code in src/aeiva/tool/tool.py
def execute(self, params: dict) -> Any:
    """
    Execute the tool synchronously by calling the corresponding function.

    Args:
        params (dict): Parameters to pass to the tool.

    Returns:
        Any: The result of the tool execution.
    """
    function_module = f"aeiva.tool.api.{self.api_name}.api"
    func_module = import_module(function_module)

    function: Callable = getattr(func_module, self.api_name)
    if asyncio.iscoroutinefunction(function):
        # The function is async: it can only be driven to completion
        # here when no event loop is already running in this thread.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running, use asyncio.run
            return asyncio.run(function(**params))
        # A loop is already running: loop.run_until_complete() would
        # raise "This event loop is already running", so async callers
        # should await aexecute() instead.
        raise RuntimeError(
            f"Tool '{self.api_name}' is async; await aexecute() inside a running event loop."
        )
    else:
        # If the function is synchronous, call it directly
        return function(**params)
load_tool_schema(api_name) classmethod

Load the tool's schema from the JSON file.

Parameters:

- api_name (str, required): The name of the API or function.

Returns:

- dict: The loaded schema from the JSON file.

Source code in src/aeiva/tool/tool.py
@classmethod
def load_tool_schema(cls, api_name: str) -> dict:
    """
    Load the tool's schema from the JSON file.
    Args:
        api_name (str): The name of the API or function.
    Returns:
        dict: The loaded schema from the JSON file.
    """
    current_path = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.abspath(os.path.join(current_path, "../../.."))
    path = os.path.join(
        project_root,
        f"src/aeiva/tool/api/{api_name}/{api_name}.json",
    )
    with open(path, "r") as file:
        return json.load(file)

toolkit

arxiv_toolkit

ArxivToolkit

Bases: Toolkit

A toolkit for interacting with the arXiv API.

Source code in src/aeiva/tool/toolkit/arxiv_toolkit.py
class ArxivToolkit(Toolkit):
    """
    A toolkit for interacting with the arXiv API.
    """

    def __init__(self, config=None):
        super().__init__(
            name="ArxivToolkit",
            tool_names=[
                "download_arxiv_papers",
                "search_arxiv_papers"
            ],
            config=config
        )
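
A hedged usage sketch (parameter names are illustrative and must match the tool's JSON schema):

# Hypothetical usage sketch; the query parameter name is illustrative.
toolkit = ArxivToolkit()
result = toolkit.execute("search_arxiv_papers", {"query": "multimodal agents"})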

auto_ui_toolkit

AutoUIToolkit

Bases: Toolkit

A toolkit for automating GUI interactions.

Source code in src/aeiva/tool/toolkit/auto_ui_toolkit.py
class AutoUIToolkit(Toolkit):
    """
    A toolkit for automating GUI interactions.
    """

    def __init__(self, config=None):
        super().__init__(
            name="AutoUIToolkit",
            tool_names=[
                "analyze_gui",
                "analyze_gui_by_ocr",
                "click_mouse",
                "click_on_element",
                "move_mouse",
                "operate_computer",
                "scroll",
                "type_into_element",
                "type_keyboard"
            ],
            config=config
        )

docx_toolkit

DocxToolkit

Bases: Toolkit

A toolkit for interacting with Docx files.

Source code in src/aeiva/tool/toolkit/docx_toolkit.py
class DocxToolkit(Toolkit):
    """
    A toolkit for interacting with Docx files.
    """

    def __init__(self, config=None):
        super().__init__(
            name="DocxToolkit",
            tool_names=[
                "create_docx",
                "docx2html",
                "docx2images",
                "docx2markdown",
                "docx2metadata",
                "docx2pdf",
                "docx2text",
                "modify_docx"
            ],
            config=config
        )

file_toolkit

FileToolkit

Bases: Toolkit

A toolkit for file-related operations.

Source code in src/aeiva/tool/toolkit/file_toolkit.py
class FileToolkit(Toolkit):
    """
    A toolkit for file-related operations.
    """

    def __init__(self, config=None):
        super().__init__(
            name="FileToolkit",
            tool_names=[
                "create_file_or_folder",
                "open_file_or_folder",
                "search_file_or_folder",
                "copy_file_or_folder",
                "move_file_or_folder",
                "change_permissions",
                "get_file_metadata",
                "delete_file",
                "edit_file",
                "find_file",
                "list_files",
                "read_file",
                "rename_file",
                "write_file"
            ],
            config=config
        )

git_toolkit

GitToolkit

Bases: Toolkit

A toolkit for interacting with Git repositories.

Source code in src/aeiva/tool/toolkit/git_toolkit.py
class GitToolkit(Toolkit):
    """
    A toolkit for interacting with Git repositories.
    """

    def __init__(self, config=None):
        super().__init__(
            name="GitToolkit",
            tool_names=[
                "git_apply_patch",
                "git_clone",
                "git_custom",
                "git_patch",
                "git_repo_tree"
            ],
            config=config
        )

math_toolkit

MathToolkit

Bases: Toolkit

A toolkit for mathematical operations.

Source code in src/aeiva/tool/toolkit/math_toolkit.py
class MathToolkit(Toolkit):
    """
    A toolkit for mathematical operations.
    """

    def __init__(self, config=None):
        super().__init__(
            name="MathToolkit",
            tool_names=["calculator"],
            config=config
        )

pdf_toolkit

PdfToolkit

Bases: Toolkit

A toolkit for interacting with PDF files.

Source code in src/aeiva/tool/toolkit/pdf_toolkit.py
class PdfToolkit(Toolkit):
    """
    A toolkit for interacting with PDF files.
    """

    def __init__(self, config=None):
        super().__init__(
            name="PdfToolkit",
            tool_names=[
                "pdf2markdown",
                "pdf2text",
                "pdf2tables",
                "pdf2images",
                "pdf2metadata",
                "pdf2ocr"
            ],
            config=config
        )

rbac

PermissionError

Bases: Exception

Custom exception for permission-related errors.

Source code in src/aeiva/tool/toolkit/rbac.py
class PermissionError(Exception):
    """Custom exception for permission-related errors."""
    pass
check_permission(user_role, api_name, config)

Check if the user_role has permission to execute the given api_name.

Parameters:

- user_role (str, required): The role of the user.
- api_name (str, required): The name of the API function.
- config (ToolkitConfig, required): The toolkit configuration containing role permissions.

Returns:

- bool: True if permitted, False otherwise.

Raises:

- PermissionError: If the user does not have the required permission.

Source code in src/aeiva/tool/toolkit/rbac.py
def check_permission(user_role: str, api_name: str, config: ToolkitConfig) -> bool:
    """
    Check if the user_role has permission to execute the given api_name.

    Args:
        user_role (str): The role of the user.
        api_name (str): The name of the API function.
        config (ToolkitConfig): The toolkit configuration containing role permissions.

    Returns:
        bool: True if permitted, False otherwise.

    Raises:
        PermissionError: If the user does not have the required permission.
    """
    allowed_apis: List[str] = config.role_permissions.get(user_role, [])
    if api_name in allowed_apis:
        return True
    else:
        return False
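
A hedged sketch (assumes ToolkitConfig accepts a role_permissions mapping; role and API names are illustrative):

# Hypothetical sketch; assumes ToolkitConfig exposes role_permissions.
config = ToolkitConfig(role_permissions={"admin": ["delete_file"], "viewer": []})

check_permission("admin", "delete_file", config)   # True
check_permission("viewer", "delete_file", config)  # False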

security

sanitize_file_path(file_path, config)

Sanitize the file path to prevent directory traversal attacks.

Parameters:

- file_path (str, required): The input file path.
- config (ToolkitConfig, required): The configuration instance.

Returns:

- str: The sanitized absolute file path.

Raises:

- ValueError: If the file path is not within allowed directories.

Source code in src/aeiva/tool/toolkit/security.py
def sanitize_file_path(file_path: str, config: ToolkitConfig) -> str:
    """
    Sanitize the file path to prevent directory traversal attacks.

    Args:
        file_path (str): The input file path.
        config (ToolkitConfig): The configuration instance.

    Returns:
        str: The sanitized absolute file path.

    Raises:
        ValueError: If the file path is not within allowed directories.
    """
    # Resolve the absolute path
    try:
        absolute_path = Path(file_path).resolve(strict=False)
    except Exception as e:
        logger.error(f"Error resolving file path '{file_path}': {e}")
        raise ValueError(f"Invalid file path: {e}")

    # Check if the path is within allowed directories
    allowed = False
    for dir_path in config.allowed_directories:
        try:
            allowed_dir = Path(dir_path).resolve(strict=False)
            if allowed_dir in absolute_path.parents or allowed_dir == absolute_path.parent:
                allowed = True
                break
        except Exception as e:
            logger.error(f"Error resolving allowed directory '{dir_path}': {e}")
            continue

    if not allowed:
        logger.error(f"Unauthorized file path access attempt: {absolute_path}")
        raise ValueError("Unauthorized file path.")

    return str(absolute_path)
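
A hedged sketch (assumes ToolkitConfig accepts an allowed_directories list; the paths are illustrative):

# Hypothetical sketch; only paths under /tmp/aeiva are permitted.
config = ToolkitConfig(allowed_directories=["/tmp/aeiva"])

sanitize_file_path("/tmp/aeiva/notes.txt", config)  # returns the absolute path
sanitize_file_path("/etc/passwd", config)           # raises ValueError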

shell_toolkit

ShellToolkit

Bases: Toolkit

A toolkit for shell and terminal operations.

Source code in src/aeiva/tool/toolkit/shell_toolkit.py
class ShellToolkit(Toolkit):
    """
    A toolkit for shell and terminal operations.
    """

    def __init__(self, config=None):
        super().__init__(
            name="ShellToolkit",
            tool_names=[
                "chwdir",
                "execute_bash_command",
                "execute_script",
                "grep",
                "create_new_shell_session"
            ],
            config=config
        )

system_toolkit

SystemToolkit

Bases: Toolkit

A toolkit for interacting with system-level operations.

Source code in src/aeiva/tool/toolkit/system_toolkit.py
class SystemToolkit(Toolkit):
    """
    A toolkit for interacting with system-level operations.
    """

    def __init__(self, config=None):
        super().__init__(
            name="SystemToolkit",
            tool_names=[
                "get_system_info",
                "get_package_root",
                "get_user_home_path",
                "open_application",
                "close_application",
                "percept_terminal_input",
                "play_music",
                "stop_music",
                "take_screenshot"
                "list_processes",
                "kill_process",
                "monitor_process",
                "get_network_info",
                "check_internet_connection",
                "get_disk_usage",
                "clean_temp_files",
                "list_drives",
                "get_env_var",
                "set_env_var",
                "update_system_packages",
                "install_package",
                "create_user",
                "delete_user",
                "change_user_password",
                "view_system_logs",
                "monitor_system_resources",
            ],
            config=config
        )

toolkit

Toolkit

Toolkit class that manages multiple Tool instances, handles validation, enforces RBAC, and manages shared resources.

Source code in src/aeiva/tool/toolkit/toolkit.py
class Toolkit:
    """
    Toolkit class that manages multiple Tool instances, handles validation,
    enforces RBAC, and manages shared resources.
    """

    subclasses: Dict[str, Type['Toolkit']] = {}

    def __init_subclass__(cls, **kwargs):
        """
        Automatically register subclasses in the Toolkit's subclasses dictionary.
        """
        super().__init_subclass__(**kwargs)
        Toolkit.subclasses[cls.__name__] = cls

    def __init__(self, name: str, tool_names: List[str], config: Optional[ToolkitConfig] = None):
        """
        Initialize the Toolkit with a name, list of tool names, and optional configuration.

        Args:
            name (str): The name of the toolkit.
            tool_names (List[str]): The names of tools to be managed by the toolkit.
            config (Optional[ToolkitConfig]): Configuration for security and roles.
        """
        self.toolkit_name = name
        self.tool_names = tool_names
        self.config = config
        self.tools: Dict[str, Tool] = {}
        self.tool_schemas: Dict[str, Dict] = {}
        self.tool_models: Dict[str, Tuple[Type[BaseModel], Type[BaseModel]]] = {}
        self.shared_resources = None  # Placeholder for shared resources

        # Setup the toolkit
        self.setup()

    def setup(self):
        """
        Setup the toolkit by loading tools, their schemas, and initializing shared resources.
        """
        logger.info(f"Setting up toolkit '{self.toolkit_name}'.")

        # Load tools and their schemas
        for tool_name in self.tool_names:
            tool = Tool(api_name=tool_name)
            self.tools[tool_name] = tool
            schema = tool.load_tool_schema(tool_name)
            self.tool_schemas[tool_name] = schema
            logger.debug(f"Loaded schema for tool '{tool_name}': {schema}")

        # Load Pydantic models for validation
        self.load_pydantic_models_for_all_tools()

        # Initialize shared resources
        self.init_shared_resources()

    def load_pydantic_models_for_all_tools(self):
        """
        Load Pydantic models (Params and Result) for all tools for validation.
        """
        logger.info("Loading Pydantic models for all tools.")
        for tool_name in self.tool_names:
            try:
                param_model, result_model = self.load_pydantic_models_for_tool(tool_name)
                self.tool_models[tool_name] = (param_model, result_model)
                logger.debug(f"Loaded models for tool '{tool_name}': Params={param_model}, Result={result_model}")
            except Exception as e:
                logger.error(f"Failed to load models for tool '{tool_name}': {e}")
                raise

    def load_pydantic_models_for_tool(self, api_name: str) -> Tuple[Type[BaseModel], Type[BaseModel]]:
        """
        Load the parameter and result Pydantic models for the given API.

        Args:
            api_name (str): The name of the API function.

        Returns:
            Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.

        Raises:
            ValueError: If models cannot be loaded.
        """
        module_path = f"aeiva.tool.api.{api_name}.model"  # Adjusted as per user's path
        try:
            models_module = importlib.import_module(module_path)
            param_model_class = getattr(models_module, f"{snake_to_camel(api_name)}Params", None)
            result_model_class = getattr(models_module, f"{snake_to_camel(api_name)}Result", None)
            if not (param_model_class and issubclass(param_model_class, BaseModel)):
                logger.error(f"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.")
                raise ValueError(f"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.")
            if not (result_model_class and issubclass(result_model_class, BaseModel)):
                logger.error(f"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.")
                raise ValueError(f"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.")
            return param_model_class, result_model_class
        except ImportError as e:
            logger.error(f"Error importing models from '{module_path}': {e}")
            raise ImportError(f"Error importing models from '{module_path}': {e}")
        except AttributeError as e:
            logger.error(f"Error accessing model classes in '{module_path}': {e}")
            raise ValueError(f"Error accessing model classes in '{module_path}': {e}")

    def init_shared_resources(self):
        """
        Initialize shared resources required by the toolkit.
        Override this method in subclasses if needed.
        """
        logger.info("Initializing shared resources.")
        # Placeholder for initializing shared resources like databases, servers, etc.
        # Example:
        # self.shared_resources = initialize_database_connection()
        pass

    def teardown(self):
        """
        Teardown the toolkit by unloading tools, their schemas, and cleaning up shared resources.
        """
        logger.info(f"Tearing down toolkit '{self.toolkit_name}'.")

        # Clean up shared resources
        self.teardown_shared_resources()

        # Clear loaded data
        self.tools.clear()
        self.tool_schemas.clear()
        self.tool_models.clear()

    def teardown_shared_resources(self):
        """
        Teardown shared resources.
        Override this method in subclasses if needed.
        """
        logger.info("Tearing down shared resources.")
        # Placeholder for tearing down shared resources
        # Example:
        # if self.shared_resources:
        #     self.shared_resources.close()
        pass

    @asynccontextmanager
    async def acontext(self):
        """
        Asynchronous context manager to handle setup and teardown of shared resources.

        Usage:
            async with toolkit.acontext():
                # Execute tools
        """
        try:
            await self.asetup()
            yield self
        finally:
            await self.ateardown()

    @contextmanager
    def context(self):
        """
        Synchronous context manager to handle setup and teardown of shared resources.

        Usage:
            with toolkit.context():
                # Execute tools
        """
        try:
            self.setup()
            yield self
        finally:
            self.teardown()

    async def asetup(self):
        """
        Asynchronously setup shared resources.
        """
        logger.info(f"Asynchronously setting up toolkit '{self.toolkit_name}'.")
        # Override in subclasses if asynchronous setup is required
        pass

    async def ateardown(self):
        """
        Asynchronously teardown shared resources.
        """
        logger.info(f"Asynchronously tearing down toolkit '{self.toolkit_name}'.")
        # Override in subclasses if asynchronous teardown is required
        self.teardown()

    def execute(self, api_name: str, params: Dict[str, Any]) -> Any:
        """
        Synchronously execute a tool's API function with validation and RBAC checks.

        Args:
            api_name (str): The name of the API function to execute.
            params (Dict[str, Any]): The parameters for the API function.

        Returns:
            Any: The result of the tool execution.

        Raises:
            ValueError: If tool not found or parameter validation fails.
            PermissionError: If user does not have permission.
            RuntimeError: If tool execution fails.
        """
        tool = self.tools.get(api_name)
        if not tool:
            logger.error(f"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.")
            raise ValueError(f"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.")

        # Perform RBAC check if config is provided
        if self.config:
            # Automatically retrieve user role from OS
            os_user = get_os_user()
            user_role = self.config.user_role_mapping.get(os_user)
            if not user_role:
                logger.error(f"OS user '{os_user}' does not have an assigned role.")
                raise ValueError(f"OS user '{os_user}' does not have an assigned role.")
            if not check_permission(user_role, api_name, self.config):
                logger.error(f"User role '{user_role}' does not have permission to execute '{api_name}'.")
                raise PermissionError(f"User role '{user_role}' does not have permission to execute '{api_name}'.")

        # Load the Pydantic models for validation
        param_model, result_model = self.tool_models.get(api_name, (None, None))
        if not param_model or not result_model:
            logger.error(f"Pydantic models for tool '{api_name}' are not loaded.")
            raise ValueError(f"Pydantic models for tool '{api_name}' are not loaded.")

        # Instantiate and validate the parameter model
        try:
            param_instance = param_model(**params)
            logger.debug(f"Validated input parameters for '{api_name}': {param_instance}")
        except Exception as e:
            logger.error(f"Error parsing parameters for '{api_name}': {e}")
            raise ValueError(f"Invalid parameters for '{api_name}': {e}")

        # Perform security checks on parameters if needed
        param_instance = self.perform_security_checks(param_instance)

        # Execute the API function via the Tool
        try:
            raw_result = tool.execute(param_instance.dict())
            logger.debug(f"Raw result from '{api_name}': {raw_result}")
        except Exception as e:
            logger.error(f"Error executing tool '{api_name}': {e}")
            raise RuntimeError(f"Error executing tool '{api_name}': {e}")

        # Validate the result using the result model
        try:
            result_instance = result_model(**raw_result)
            logger.info(f"Execution of '{api_name}' successful with result: {result_instance}")
            return result_instance
        except Exception as e:
            logger.error(f"Error parsing result for '{api_name}': {e}")
            raise ValueError(f"Invalid result from '{api_name}': {e}")

    async def aexecute(self, api_name: str, params: Dict[str, Any]) -> Any:
        """
        Asynchronously execute a tool's API function with validation and RBAC checks.

        Args:
            api_name (str): The name of the API function to execute.
            params (Dict[str, Any]): The parameters for the API function.

        Returns:
            Any: The result of the tool execution.

        Raises:
            ValueError: If tool not found or parameter validation fails.
            PermissionError: If user does not have permission.
            RuntimeError: If tool execution fails.
        """
        tool = self.tools.get(api_name)
        if not tool:
            logger.error(f"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.")
            raise ValueError(f"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.")

        # Perform RBAC check if config is provided
        if self.config:
            # Automatically retrieve user role from OS
            os_user = get_os_user()
            user_role = self.config.user_role_mapping.get(os_user)
            if not user_role:
                logger.error(f"OS user '{os_user}' does not have an assigned role.")
                raise ValueError(f"OS user '{os_user}' does not have an assigned role.")
            if not check_permission(user_role, api_name, self.config):
                logger.error(f"User role '{user_role}' does not have permission to execute '{api_name}'.")
                raise PermissionError(f"User role '{user_role}' does not have permission to execute '{api_name}'.")

        # Load the Pydantic models for validation
        param_model, result_model = self.tool_models.get(api_name, (None, None))
        if not param_model or not result_model:
            logger.error(f"Pydantic models for tool '{api_name}' are not loaded.")
            raise ValueError(f"Pydantic models for tool '{api_name}' are not loaded.")

        # Instantiate and validate the parameter model
        try:
            param_instance = param_model(**params)
            logger.debug(f"Validated input parameters for '{api_name}': {param_instance}")
        except Exception as e:
            logger.error(f"Error parsing parameters for '{api_name}': {e}")
            raise ValueError(f"Invalid parameters for '{api_name}': {e}")

        # Perform security checks on parameters if needed
        param_instance = self.perform_security_checks(param_instance)

        # Execute the API function via the Tool
        try:
            raw_result = await tool.aexecute(param_instance.dict())
            logger.debug(f"Raw result from '{api_name}': {raw_result}")
        except Exception as e:
            logger.error(f"Error executing tool '{api_name}': {e}")
            raise RuntimeError(f"Error executing tool '{api_name}': {e}")

        # Validate the result using the result model
        try:
            result_instance = result_model(**raw_result)
            logger.info(f"Execution of '{api_name}' successful with result: {result_instance}")
            return result_instance
        except Exception as e:
            logger.error(f"Error parsing result for '{api_name}': {e}")
            raise ValueError(f"Invalid result from '{api_name}': {e}")

    def perform_security_checks(self, param_instance: BaseModel) -> BaseModel:
        """
        Perform security checks on parameters that require sanitization.

        Args:
            param_instance (BaseModel): The validated parameter instance.

        Returns:
            BaseModel: The sanitized parameter instance.

        Raises:
            ValueError: If sanitization fails for any field or if config is required but not provided.
        """
        sanitized_params = param_instance.dict()

        for field_name, field in param_instance.__fields__.items():
            sanitize = field.field_info.extra.get('sanitize', False)
            if not sanitize:
                continue  # Skip fields that do not require sanitization

            field_type = field.type_
            origin = get_origin(field_type)
            args = get_args(field_type)

            # Determine if the field is a string type or contains string types
            is_string_field = False

            if field_type == str:
                is_string_field = True
            elif origin is Union and str in args:
                is_string_field = True
            elif origin is list and len(args) == 1 and args[0] == str:
                is_string_field = True
            elif origin is Optional and str in args:
                is_string_field = True
            # Add more conditions here if there are other complex types containing strings

            if is_string_field:
                original_value = sanitized_params.get(field_name)
                if original_value is None:
                    continue  # Skip if the field value is None

                if self.config is None:
                    logger.error(
                        f"Configuration is required to sanitize field '{field_name}', "
                        f"but config is not provided."
                    )
                    raise ValueError(
                        f"Configuration is required to sanitize field '{field_name}', "
                        f"but config is not provided."
                    )

                try:
                    # If the field is a list of strings, sanitize each path individually
                    if origin is list and len(args) == 1 and args[0] == str:
                        if not isinstance(original_value, list):
                            logger.error(
                                f"Expected a list for field '{field_name}', "
                                f"got {type(original_value)}."
                            )
                            raise ValueError(
                                f"Expected a list for field '{field_name}'."
                            )
                        sanitized_list = []
                        for idx, item in enumerate(original_value):
                            sanitized_item = sanitize_file_path(item, self.config)
                            sanitized_list.append(sanitized_item)
                            logger.debug(
                                f"Sanitized '{field_name}[{idx}]': '{item}' -> '{sanitized_item}'"
                            )
                        sanitized_params[field_name] = sanitized_list
                    else:
                        # Sanitize single string path
                        sanitized_path = sanitize_file_path(original_value, self.config)
                        sanitized_params[field_name] = sanitized_path
                        logger.debug(
                            f"Sanitized '{field_name}': '{original_value}' -> '{sanitized_path}'"
                        )
                except ValueError as ve:
                    logger.error(
                        f"Sanitization failed for field '{field_name}': {ve}"
                    )
                    raise

        # Create a new instance of the parameter model with sanitized parameters
        sanitized_instance = param_instance.copy(update=sanitized_params)

        return sanitized_instance

    def generate_documentation(self) -> str:
        """
        Generate documentation for all functions managed by this toolkit based on their schemas.

        Returns:
            str: Generated documentation as a markdown string.
        """
        doc = f"# Toolkit: {self.toolkit_name}\n\n"
        for api_name, tool in self.tools.items():
            schema = self.tool_schemas.get(api_name, {})
            if not schema:
                continue
            doc += f"## Function: {api_name}\n\n"
            doc += f"**Description:** {schema.get('description', 'No description provided.')}\n\n"
            doc += "### Parameters:\n\n"
            parameters = schema.get("parameters", {})
            for prop, details in parameters.get("properties", {}).items():
                req = " (required)" if prop in parameters.get("required", []) else ""
                description = details.get("description", "")
                default = f" (default: {details.get('default')})" if "default" in details else ""
                doc += f"- **{prop}** ({details.get('type', 'any')}): {description}{default}{req}\n"
            doc += "\n### Example:\n\n"
            example = schema.get("example", "No example provided.")
            if isinstance(example, dict):
                example = json.dumps(example, indent=4)
            doc += f"```json\n{example}\n```\n\n"
        return doc
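
A hedged end-to-end sketch (uses the FileToolkit subclass shown earlier; the tool parameter name is illustrative and must match the tool's schema):

# Hypothetical usage sketch; parameter names depend on the tool's schema.
toolkit = FileToolkit()

with toolkit.context():
    result = toolkit.execute("read_file", {"file_path": "/tmp/aeiva/notes.txt"})
    print(result)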
__init__(name, tool_names, config=None)

Initialize the Toolkit with a name, list of tool names, and optional configuration.

Parameters:

- name (str, required): The name of the toolkit.
- tool_names (List[str], required): The names of tools to be managed by the toolkit.
- config (Optional[ToolkitConfig], default None): Configuration for security and roles.
Source code in src/aeiva/tool/toolkit/toolkit.py
def __init__(self, name: str, tool_names: List[str], config: Optional[ToolkitConfig] = None):
    """
    Initialize the Toolkit with a name, list of tool names, and optional configuration.

    Args:
        name (str): The name of the toolkit.
        tool_names (List[str]): The names of tools to be managed by the toolkit.
        config (Optional[ToolkitConfig]): Configuration for security and roles.
    """
    self.toolkit_name = name
    self.tool_names = tool_names
    self.config = config
    self.tools: Dict[str, Tool] = {}
    self.tool_schemas: Dict[str, Dict] = {}
    self.tool_models: Dict[str, Tuple[Type[BaseModel], Type[BaseModel]]] = {}
    self.shared_resources = None  # Placeholder for shared resources

    # Setup the toolkit
    self.setup()
__init_subclass__(**kwargs)

Automatically register subclasses in the Toolkit's subclasses dictionary.

Source code in src/aeiva/tool/toolkit/toolkit.py
def __init_subclass__(cls, **kwargs):
    """
    Automatically register subclasses in the Toolkit's subclasses dictionary.
    """
    super().__init_subclass__(**kwargs)
    Toolkit.subclasses[cls.__name__] = cls
acontext() async

Asynchronous context manager to handle setup and teardown of shared resources.

Usage:

    async with toolkit.acontext():
        # Execute tools

Source code in src/aeiva/tool/toolkit/toolkit.py
@asynccontextmanager
async def acontext(self):
    """
    Asynchronous context manager to handle setup and teardown of shared resources.

    Usage:
        async with toolkit.acontext():
            # Execute tools
    """
    try:
        await self.asetup()
        yield self
    finally:
        await self.ateardown()
aexecute(api_name, params) async

Asynchronously execute a tool's API function with validation and RBAC checks.

Parameters:

- api_name (str, required): The name of the API function to execute.
- params (Dict[str, Any], required): The parameters for the API function.

Returns:

- Any: The result of the tool execution.

Raises:

- ValueError: If tool not found or parameter validation fails.
- PermissionError: If user does not have permission.
- RuntimeError: If tool execution fails.

Source code in src/aeiva/tool/toolkit/toolkit.py
async def aexecute(self, api_name: str, params: Dict[str, Any]) -> Any:
    """
    Asynchronously execute a tool's API function with validation and RBAC checks.

    Args:
        api_name (str): The name of the API function to execute.
        params (Dict[str, Any]): The parameters for the API function.

    Returns:
        Any: The result of the tool execution.

    Raises:
        ValueError: If tool not found or parameter validation fails.
        PermissionError: If user does not have permission.
        RuntimeError: If tool execution fails.
    """
    tool = self.tools.get(api_name)
    if not tool:
        logger.error(f"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.")
        raise ValueError(f"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.")

    # Perform RBAC check if config is provided
    if self.config:
        # Automatically retrieve user role from OS
        os_user = get_os_user()
        user_role = self.config.user_role_mapping.get(os_user)
        if not user_role:
            logger.error(f"OS user '{os_user}' does not have an assigned role.")
            raise ValueError(f"OS user '{os_user}' does not have an assigned role.")
        if not check_permission(user_role, api_name, self.config):
            logger.error(f"User role '{user_role}' does not have permission to execute '{api_name}'.")
            raise PermissionError(f"User role '{user_role}' does not have permission to execute '{api_name}'.")

    # Load the Pydantic models for validation
    param_model, result_model = self.tool_models.get(api_name, (None, None))
    if not param_model or not result_model:
        logger.error(f"Pydantic models for tool '{api_name}' are not loaded.")
        raise ValueError(f"Pydantic models for tool '{api_name}' are not loaded.")

    # Instantiate and validate the parameter model
    try:
        param_instance = param_model(**params)
        logger.debug(f"Validated input parameters for '{api_name}': {param_instance}")
    except Exception as e:
        logger.error(f"Error parsing parameters for '{api_name}': {e}")
        raise ValueError(f"Invalid parameters for '{api_name}': {e}")

    # Perform security checks on parameters if needed
    param_instance = self.perform_security_checks(param_instance)

    # Execute the API function via the Tool
    try:
        raw_result = await tool.aexecute(param_instance.dict())
        logger.debug(f"Raw result from '{api_name}': {raw_result}")
    except Exception as e:
        logger.error(f"Error executing tool '{api_name}': {e}")
        raise RuntimeError(f"Error executing tool '{api_name}': {e}")

    # Validate the result using the result model
    try:
        result_instance = result_model(**raw_result)
        logger.info(f"Execution of '{api_name}' successful with result: {result_instance}")
        return result_instance
    except Exception as e:
        logger.error(f"Error parsing result for '{api_name}': {e}")
        raise ValueError(f"Invalid result from '{api_name}': {e}")
asetup() async

Asynchronously setup shared resources.

Source code in src/aeiva/tool/toolkit/toolkit.py
async def asetup(self):
    """
    Asynchronously setup shared resources.
    """
    logger.info(f"Asynchronously setting up toolkit '{self.toolkit_name}'.")
    # Override in subclasses if asynchronous setup is required
    pass
ateardown() async

Asynchronously teardown shared resources.

Source code in src/aeiva/tool/toolkit/toolkit.py
async def ateardown(self):
    """
    Asynchronously teardown shared resources.
    """
    logger.info(f"Asynchronously tearing down toolkit '{self.toolkit_name}'.")
    # Override in subclasses if asynchronous teardown is required
    self.teardown()
context()

Synchronous context manager to handle setup and teardown of shared resources.

Usage

with toolkit.context():
    # Execute tools

Source code in src/aeiva/tool/toolkit/toolkit.py
@contextmanager
def context(self):
    """
    Synchronous context manager to handle setup and teardown of shared resources.

    Usage:
        with toolkit.context():
            # Execute tools
    """
    try:
        self.setup()
        yield self
    finally:
        self.teardown()
execute(api_name, params)

Synchronously execute a tool's API function with validation and RBAC checks.

Parameters:

- api_name (str): The name of the API function to execute. (required)
- params (Dict[str, Any]): The parameters for the API function. (required)

Returns:

- Any: The result of the tool execution.

Raises:

- ValueError: If tool not found or parameter validation fails.
- PermissionError: If user does not have permission.
- RuntimeError: If tool execution fails.

Source code in src/aeiva/tool/toolkit/toolkit.py
def execute(self, api_name: str, params: Dict[str, Any]) -> Any:
    """
    Synchronously execute a tool's API function with validation and RBAC checks.

    Args:
        api_name (str): The name of the API function to execute.
        params (Dict[str, Any]): The parameters for the API function.

    Returns:
        Any: The result of the tool execution.

    Raises:
        ValueError: If tool not found or parameter validation fails.
        PermissionError: If user does not have permission.
        RuntimeError: If tool execution fails.
    """
    tool = self.tools.get(api_name)
    if not tool:
        logger.error(f"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.")
        raise ValueError(f"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.")

    # Perform RBAC check if config is provided
    if self.config:
        # Automatically retrieve user role from OS
        os_user = get_os_user()
        user_role = self.config.user_role_mapping.get(os_user)
        if not user_role:
            logger.error(f"OS user '{os_user}' does not have an assigned role.")
            raise ValueError(f"OS user '{os_user}' does not have an assigned role.")
        if not check_permission(user_role, api_name, self.config):
            logger.error(f"User role '{user_role}' does not have permission to execute '{api_name}'.")
            raise PermissionError(f"User role '{user_role}' does not have permission to execute '{api_name}'.")

    # Load the Pydantic models for validation
    param_model, result_model = self.tool_models.get(api_name, (None, None))
    if not param_model or not result_model:
        logger.error(f"Pydantic models for tool '{api_name}' are not loaded.")
        raise ValueError(f"Pydantic models for tool '{api_name}' are not loaded.")

    # Instantiate and validate the parameter model
    try:
        param_instance = param_model(**params)
        logger.debug(f"Validated input parameters for '{api_name}': {param_instance}")
    except Exception as e:
        logger.error(f"Error parsing parameters for '{api_name}': {e}")
        raise ValueError(f"Invalid parameters for '{api_name}': {e}")

    # Perform security checks on parameters if needed
    param_instance = self.perform_security_checks(param_instance)

    # Execute the API function via the Tool
    try:
        raw_result = tool.execute(param_instance.dict())
        logger.debug(f"Raw result from '{api_name}': {raw_result}")
    except Exception as e:
        logger.error(f"Error executing tool '{api_name}': {e}")
        raise RuntimeError(f"Error executing tool '{api_name}': {e}")

    # Validate the result using the result model
    try:
        result_instance = result_model(**raw_result)
        logger.info(f"Execution of '{api_name}' successful with result: {result_instance}")
        return result_instance
    except Exception as e:
        logger.error(f"Error parsing result for '{api_name}': {e}")
        raise ValueError(f"Invalid result from '{api_name}': {e}")
generate_documentation()

Generate documentation for all functions managed by this toolkit based on their schemas.

Returns:

- str: Generated documentation as a markdown string.

Source code in src/aeiva/tool/toolkit/toolkit.py
def generate_documentation(self) -> str:
    """
    Generate documentation for all functions managed by this toolkit based on their schemas.

    Returns:
        str: Generated documentation as a markdown string.
    """
    doc = f"# Toolkit: {self.toolkit_name}\n\n"
    for api_name, tool in self.tools.items():
        schema = self.tool_schemas.get(api_name, {})
        if not schema:
            continue
        doc += f"## Function: {api_name}\n\n"
        doc += f"**Description:** {schema.get('description', 'No description provided.')}\n\n"
        doc += "### Parameters:\n\n"
        parameters = schema.get("parameters", {})
        for prop, details in parameters.get("properties", {}).items():
            req = " (required)" if prop in parameters.get("required", []) else ""
            description = details.get("description", "")
            default = f" (default: {details.get('default')})" if "default" in details else ""
            doc += f"- **{prop}** ({details.get('type', 'any')}): {description}{default}{req}\n"
        doc += "\n### Example:\n\n"
        example = schema.get("example", "No example provided.")
        if isinstance(example, dict):
            example = json.dumps(example, indent=4)
        doc += f"```json\n{example}\n```\n\n"
    return doc
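
A sketch of rendering the generated markdown to disk; setup() must run first so the tool schemas are loaded (toolkit construction as in the sketches above, with placeholder names):

from aeiva.tool.toolkit.toolkit import Toolkit  # import path assumed

toolkit = Toolkit(name="FileToolkit", tool_names=["view_file"], config=None)  # placeholders
with toolkit.context():  # populates self.tools and self.tool_schemas
    markdown = toolkit.generate_documentation()

with open("toolkit_reference.md", "w", encoding="utf-8") as f:
    f.write(markdown)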
init_shared_resources()

Initialize shared resources required by the toolkit. Override this method in subclasses if needed.

Source code in src/aeiva/tool/toolkit/toolkit.py
def init_shared_resources(self):
    """
    Initialize shared resources required by the toolkit.
    Override this method in subclasses if needed.
    """
    logger.info("Initializing shared resources.")
    # Placeholder for initializing shared resources like databases, servers, etc.
    # Example:
    # self.shared_resources = initialize_database_connection()
    pass
load_pydantic_models_for_all_tools()

Load Pydantic models (Params and Result) for all tools for validation.

Source code in src/aeiva/tool/toolkit/toolkit.py
def load_pydantic_models_for_all_tools(self):
    """
    Load Pydantic models (Params and Result) for all tools for validation.
    """
    logger.info("Loading Pydantic models for all tools.")
    for tool_name in self.tool_names:
        try:
            param_model, result_model = self.load_pydantic_models_for_tool(tool_name)
            self.tool_models[tool_name] = (param_model, result_model)
            logger.debug(f"Loaded models for tool '{tool_name}': Params={param_model}, Result={result_model}")
        except Exception as e:
            logger.error(f"Failed to load models for tool '{tool_name}': {e}")
            raise
load_pydantic_models_for_tool(api_name)

Load the parameter and result Pydantic models for the given API.

Parameters:

- api_name (str): The name of the API function. (required)

Returns:

- Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.

Raises:

- ValueError: If models cannot be loaded.

Source code in src/aeiva/tool/toolkit/toolkit.py
def load_pydantic_models_for_tool(self, api_name: str) -> Tuple[Type[BaseModel], Type[BaseModel]]:
    """
    Load the parameter and result Pydantic models for the given API.

    Args:
        api_name (str): The name of the API function.

    Returns:
        Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.

    Raises:
        ValueError: If models cannot be loaded.
    """
    module_path = f"aeiva.tool.api.{api_name}.model"  # Adjusted as per user's path
    try:
        models_module = importlib.import_module(module_path)
        param_model_class = getattr(models_module, f"{snake_to_camel(api_name)}Params", None)
        result_model_class = getattr(models_module, f"{snake_to_camel(api_name)}Result", None)
        if not (param_model_class and issubclass(param_model_class, BaseModel)):
            logger.error(f"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.")
            raise ValueError(f"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.")
        if not (result_model_class and issubclass(result_model_class, BaseModel)):
            logger.error(f"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.")
            raise ValueError(f"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.")
        return param_model_class, result_model_class
    except ImportError as e:
        logger.error(f"Error importing models from '{module_path}': {e}")
        raise ImportError(f"Error importing models from '{module_path}': {e}")
    except AttributeError as e:
        logger.error(f"Error accessing model classes in '{module_path}': {e}")
        raise ValueError(f"Error accessing model classes in '{module_path}': {e}")
perform_security_checks(param_instance)

Perform security checks on parameters that require sanitization.

Parameters:

- param_instance (BaseModel): The validated parameter instance. (required)

Returns:

- BaseModel: The sanitized parameter instance.

Raises:

- ValueError: If sanitization fails for any field or if config is required but not provided.

Source code in src/aeiva/tool/toolkit/toolkit.py
def perform_security_checks(self, param_instance: BaseModel) -> BaseModel:
    """
    Perform security checks on parameters that require sanitization.

    Args:
        param_instance (BaseModel): The validated parameter instance.

    Returns:
        BaseModel: The sanitized parameter instance.

    Raises:
        ValueError: If sanitization fails for any field or if config is required but not provided.
    """
    sanitized_params = param_instance.dict()

    for field_name, field in param_instance.__fields__.items():
        sanitize = field.field_info.extra.get('sanitize', False)
        if not sanitize:
            continue  # Skip fields that do not require sanitization

        field_type = field.type_
        origin = get_origin(field_type)
        args = get_args(field_type)

        # Determine if the field is a string type or contains string types
        is_string_field = False

        if field_type == str:
            is_string_field = True
        elif origin is Union and str in args:
            is_string_field = True
        elif origin is list and len(args) == 1 and args[0] == str:
            is_string_field = True
        elif origin is Optional and str in args:
            is_string_field = True
        # Add more conditions here if there are other complex types containing strings

        if is_string_field:
            original_value = sanitized_params.get(field_name)
            if original_value is None:
                continue  # Skip if the field value is None

            if self.config is None:
                logger.error(
                    f"Configuration is required to sanitize field '{field_name}', "
                    f"but config is not provided."
                )
                raise ValueError(
                    f"Configuration is required to sanitize field '{field_name}', "
                    f"but config is not provided."
                )

            try:
                # If the field is a list of strings, sanitize each path individually
                if origin is list and len(args) == 1 and args[0] == str:
                    if not isinstance(original_value, list):
                        logger.error(
                            f"Expected a list for field '{field_name}', "
                            f"got {type(original_value)}."
                        )
                        raise ValueError(
                            f"Expected a list for field '{field_name}'."
                        )
                    sanitized_list = []
                    for idx, item in enumerate(original_value):
                        sanitized_item = sanitize_file_path(item, self.config)
                        sanitized_list.append(sanitized_item)
                        logger.debug(
                            f"Sanitized '{field_name}[{idx}]': '{item}' -> '{sanitized_item}'"
                        )
                    sanitized_params[field_name] = sanitized_list
                else:
                    # Sanitize single string path
                    sanitized_path = sanitize_file_path(original_value, self.config)
                    sanitized_params[field_name] = sanitized_path
                    logger.debug(
                        f"Sanitized '{field_name}': '{original_value}' -> '{sanitized_path}'"
                    )
            except ValueError as ve:
                logger.error(
                    f"Sanitization failed for field '{field_name}': {ve}"
                )
                raise

    # Create a new instance of the parameter model with sanitized parameters
    sanitized_instance = param_instance.copy(update=sanitized_params)

    return sanitized_instance
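
Sanitization is opt-in per field: the loop above looks for a sanitize flag in the field's extra metadata (Pydantic v1 style, where extra keyword arguments to Field() land in field_info.extra). A sketch of a params model that opts a path field in (class and field names are illustrative):

from typing import Optional
from pydantic import BaseModel, Field

class DeleteFileParams(BaseModel):
    # sanitize=True is picked up via field.field_info.extra in perform_security_checks.
    path: str = Field(..., sanitize=True, description="Path checked against allowed_directories.")
    reason: Optional[str] = Field(None, description="Not sanitized; no flag set.")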
setup()

Setup the toolkit by loading tools, their schemas, and initializing shared resources.

Source code in src/aeiva/tool/toolkit/toolkit.py
def setup(self):
    """
    Setup the toolkit by loading tools, their schemas, and initializing shared resources.
    """
    logger.info(f"Setting up toolkit '{self.toolkit_name}'.")

    # Load tools and their schemas
    for tool_name in self.tool_names:
        tool = Tool(api_name=tool_name)
        self.tools[tool_name] = tool
        schema = tool.load_tool_schema(tool_name)
        self.tool_schemas[tool_name] = schema
        logger.debug(f"Loaded schema for tool '{tool_name}': {schema}")

    # Load Pydantic models for validation
    self.load_pydantic_models_for_all_tools()

    # Initialize shared resources
    self.init_shared_resources()
teardown()

Teardown the toolkit by unloading tools, their schemas, and cleaning up shared resources.

Source code in src/aeiva/tool/toolkit/toolkit.py
def teardown(self):
    """
    Teardown the toolkit by unloading tools, their schemas, and cleaning up shared resources.
    """
    logger.info(f"Tearing down toolkit '{self.toolkit_name}'.")

    # Clean up shared resources
    self.teardown_shared_resources()

    # Clear loaded data
    self.tools.clear()
    self.tool_schemas.clear()
    self.tool_models.clear()
teardown_shared_resources()

Teardown shared resources. Override this method in subclasses if needed.

Source code in src/aeiva/tool/toolkit/toolkit.py
def teardown_shared_resources(self):
    """
    Teardown shared resources.
    Override this method in subclasses if needed.
    """
    logger.info("Tearing down shared resources.")
    # Placeholder for tearing down shared resources
    # Example:
    # if self.shared_resources:
    #     self.shared_resources.close()
    pass

toolkit_config

ToolkitConfig dataclass

Bases: BaseConfig

Configuration for the Toolkit.

Source code in src/aeiva/tool/toolkit/toolkit_config.py
@dataclass
class ToolkitConfig(BaseConfig):
    """
    Configuration for the Toolkit.
    """

    allowed_directories: List[str] = field(
        default_factory=lambda: ["/tmp/", "/home/user/allowed_directory/"],
        metadata={"help": "Directories that tools are allowed to access."}
    )
    # Mapping from OS usernames to roles
    user_role_mapping: Dict[str, str] = field(
        default_factory=lambda: {
            "admin_user": "admin",
            "regular_user": "user"
            # Add more user-role mappings as needed
        },
        metadata={"help": "Mapping of OS usernames to their roles."}
    )
    # Define permissions for each role
    role_permissions: Dict[str, List[str]] = field(
        default_factory=lambda: {
            "admin": ["delete_file", "view_file", "create_file"],
            "user": ["view_file", "create_file"]
        },
        metadata={"help": "Mapping of roles to allowed API functions."}
    )
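
A sketch of overriding the defaults for a concrete deployment (usernames and directories are placeholders):

from aeiva.tool.toolkit.toolkit_config import ToolkitConfig  # import path assumed

config = ToolkitConfig(
    allowed_directories=["/srv/agent_workspace/"],
    user_role_mapping={"alice": "admin", "bob": "user"},
    role_permissions={
        "admin": ["delete_file", "view_file", "create_file"],
        "user": ["view_file"],
    },
)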

web_toolkit

WebToolkit

Bases: Toolkit

A toolkit for interacting with web pages.

Source code in src/aeiva/tool/toolkit/web_toolkit.py
class WebToolkit(Toolkit):
    """
    A toolkit for interacting with web pages.
    """

    def __init__(self, config=None):
        super().__init__(
            name="WebToolkit",
            tool_names=[
                "click_webpage_element",
                "crawl",
                "execute_js_script_on_webpage",
                "get_webpage_details",
                "get_webpage_elements",
                "navigate_browser_history",
                "navigate_to_webpage",
                "refresh_webpage",
                "scrape",
                "scroll_webpage",
                "type_text_in_webpage_element",
                "web_search"
            ],
            config=config
        )
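
Running one of the web tools follows the generic toolkit pattern; a sketch (the parameter shape for navigate_to_webpage is assumed, not documented here):

import asyncio

from aeiva.tool.toolkit.web_toolkit import WebToolkit  # import path assumed

async def main():
    toolkit = WebToolkit()
    async with toolkit.acontext():
        result = await toolkit.aexecute("navigate_to_webpage", {"url": "https://example.com"})
        print(result)

asyncio.run(main())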

trainer

pl_trainer

LightningTrainer

Bases: LightningModule

Source code in src/aeiva/trainer/pl_trainer.py
class LightningTrainer(pl.LightningModule):
    def __init__(self, model, tokenizer, config):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

    def forward(self, batch):
        outputs = self.model(batch)
        return outputs

    def training_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs.loss
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs.loss
        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs.loss
        return {"loss": loss}

    def configure_optimizers(self):
        """
        Function to prepare the optimizer and learning rate scheduler for model training.
        This function separates model parameters into two categories: parameters that will experience weight decay, 
        and those that will not (e.g., bias and layernorm weights). 

        Returns:
            Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.
        """

        # List of module types that will be subjected to weight decay.
        whitelist_weight_modules = (torch.nn.Linear, ) 

        # List of module types that will not be subjected to weight decay.
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)

        # Parameter sets for decay and no decay.
        decay = set()
        no_decay = set()

        # Populate the decay and no_decay sets. 
        # Loop over all modules to get module name (mn) and module (m).
        # !!!! revise later.
        # for mn, m in self.model.named_modules():
        #     for pn, p in m.named_parameters():
        #         fpn = '%s.%s' % (mn, pn) if mn else pn 

        #         if 'bias' in pn:
        #             no_decay.add(fpn)
        #         elif 'weight' in pn:
        #             decay.add(fpn)

        param_dict = {pn: p for pn, p in self.model.named_parameters()}

        for mn, m in self.model.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                # Adding new condition to check for the 'class_embedding' and 'logit_scale' parameters
                if pn.endswith('bias') or 'class_embedding' in pn or 'logit_scale' in pn:
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
        for pn, p in param_dict.items():
            if pn not in no_decay:
                decay.add(pn)


        # # After this loop, print out all parameters in the intersection of decay and no_decay:
        # print("decay: ", decay)
        # print("no_decay: ", no_decay)
        # print("intersection: ", decay.intersection(no_decay))

        # print("difference: ", param_dict.keys() - (decay | no_decay))


        # # 'lm_head.weight' is tied to 'model.embed_tokens.weight', so it should not be decayed. 
        # # This ensures that the same tensor is not optimized in different ways.
        # decay.remove('llm.lm_head.weight')

        # Validate that we considered every parameter.
        param_dict = {pn: p for pn, p in self.model.named_parameters()}
        assert len(decay & no_decay) == 0, "Some parameters are in both decay and no_decay sets!"
        assert len(param_dict.keys() - (decay | no_decay)) == 0, "Some parameters are in neither decay nor no_decay sets!"

        # Create the PyTorch optimizer object.
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
        use_fused = (self.config.device == 'cuda') and (
            'fused' in inspect.signature(torch.optim.AdamW).parameters)
        print(f"using fused AdamW: {use_fused}")
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(
            optim_groups, lr=self.config.learning_rate, betas=(self.config.adam_beta1, self.config.adam_beta2), **extra_args)

        # Prepare learning rate scheduler.
        total_steps = self.config.max_steps
        pct_start = self.config.warmup_steps / total_steps
        final_div_factor = self.config.learning_rate / self.config.min_lr

        scheduler = {
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.config.learning_rate,
                total_steps=total_steps,
                pct_start=pct_start,
                final_div_factor=final_div_factor,
                div_factor=1.0,  # No additional scaling for the initial learning rate
                anneal_strategy='cos',  # Use cosine annealing
                cycle_momentum=False,  # Disable momentum cycling
            ),
            'interval': 'step',
            'frequency': 1
        }

        return [optimizer], [scheduler]


    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.model.parameters())
        if non_embedding:
            embedding_params = sum(p.numel() for m in self.model.modules() if isinstance(m, nn.Embedding) for p in m.parameters())
            n_params -= embedding_params
        return n_params

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu
configure_optimizers()

Function to prepare the optimizer and learning rate scheduler for model training. This function separates model parameters into two categories: parameters that will experience weight decay, and those that will not (e.g., bias and layernorm weights).

Returns:

- Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.

Source code in src/aeiva/trainer/pl_trainer.py
def configure_optimizers(self):
    """
    Function to prepare the optimizer and learning rate scheduler for model training.
    This function separates model parameters into two categories: parameters that will experience weight decay, 
    and those that will not (e.g., bias and layernorm weights). 

    Returns:
        Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.
    """

    # List of module types that will be subjected to weight decay.
    whitelist_weight_modules = (torch.nn.Linear, ) 

    # List of module types that will not be subjected to weight decay.
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)

    # Parameter sets for decay and no decay.
    decay = set()
    no_decay = set()

    # Populate the decay and no_decay sets. 
    # Loop over all modules to get module name (mn) and module (m).
    # !!!! revise later.
    # for mn, m in self.model.named_modules():
    #     for pn, p in m.named_parameters():
    #         fpn = '%s.%s' % (mn, pn) if mn else pn 

    #         if 'bias' in pn:
    #             no_decay.add(fpn)
    #         elif 'weight' in pn:
    #             decay.add(fpn)

    param_dict = {pn: p for pn, p in self.model.named_parameters()}

    for mn, m in self.model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            # random note: because named_modules and named_parameters are recursive
            # we will see the same tensors p many many times. but doing it this way
            # allows us to know which parent module any tensor p belongs to...
            # Adding new condition to check for the 'class_embedding' and 'logit_scale' parameters
            if pn.endswith('bias') or 'class_embedding' in pn or 'logit_scale' in pn:
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)
    for pn, p in param_dict.items():
        if pn not in no_decay:
            decay.add(pn)


    # # After this loop, print out all parameters in the intersection of decay and no_decay:
    # print("decay: ", decay)
    # print("no_decay: ", no_decay)
    # print("intersection: ", decay.intersection(no_decay))

    # print("difference: ", param_dict.keys() - (decay | no_decay))


    # # 'lm_head.weight' is tied to 'model.embed_tokens.weight', so it should not be decayed. 
    # # This ensures that the same tensor is not optimized in different ways.
    # decay.remove('llm.lm_head.weight')

    # Validate that we considered every parameter.
    param_dict = {pn: p for pn, p in self.model.named_parameters()}
    assert len(decay & no_decay) == 0, "Some parameters are in both decay and no_decay sets!"
    assert len(param_dict.keys() - (decay | no_decay)) == 0, "Some parameters are in neither decay nor no_decay sets!"

    # Create the PyTorch optimizer object.
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.config.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]
    # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
    use_fused = (self.config.device == 'cuda') and (
        'fused' in inspect.signature(torch.optim.AdamW).parameters)
    print(f"using fused AdamW: {use_fused}")
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(
        optim_groups, lr=self.config.learning_rate, betas=(self.config.adam_beta1, self.config.adam_beta2), **extra_args)

    # Prepare learning rate scheduler.
    total_steps = self.config.max_steps
    pct_start = self.config.warmup_steps / total_steps
    final_div_factor = self.config.learning_rate / self.config.min_lr

    scheduler = {
        'scheduler': torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.config.learning_rate,
            total_steps=total_steps,
            pct_start=pct_start,
            final_div_factor=final_div_factor,
            div_factor=1.0,  # No additional scaling for the initial learning rate
            anneal_strategy='cos',  # Use cosine annealing
            cycle_momentum=False,  # Disable momentum cycling
        ),
        'interval': 'step',
        'frequency': 1
    }

    return [optimizer], [scheduler]
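
The decay/no-decay partition is easier to see in isolation; a self-contained sketch of the same grouping logic on a toy model (simplified from the method above, with an arbitrary weight-decay value):

import torch

model = torch.nn.Sequential(
    torch.nn.Embedding(100, 16),
    torch.nn.Linear(16, 16),
    torch.nn.LayerNorm(16),
)

blacklist = (torch.nn.LayerNorm, torch.nn.Embedding)  # modules whose weights skip decay
decay, no_decay = set(), set()

for mn, m in model.named_modules():
    for pn, _ in m.named_parameters(recurse=False):  # recurse=False avoids double-counting
        fpn = f"{mn}.{pn}" if mn else pn              # full parameter name
        if pn.endswith("bias") or isinstance(m, blacklist):
            no_decay.add(fpn)
        else:
            decay.add(fpn)

param_dict = dict(model.named_parameters())
optim_groups = [
    {"params": [param_dict[n] for n in sorted(decay)], "weight_decay": 0.1},
    {"params": [param_dict[n] for n in sorted(no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=3e-4)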
estimate_mfu(fwdbwd_per_iter, dt)

estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS

Source code in src/aeiva/trainer/pl_trainer.py
def estimate_mfu(self, fwdbwd_per_iter, dt):
    """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
    # first estimate the number of flops we do per iteration.
    # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
    N = self.get_num_params()
    cfg = self.config
    L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
    flops_per_token = 6*N + 12*L*H*Q*T
    flops_per_fwdbwd = flops_per_token * T
    flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
    # express our flops throughput as ratio of A100 bfloat16 peak flops
    flops_achieved = flops_per_iter * (1.0/dt)  # per second
    flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
    mfu = flops_achieved / flops_promised
    return mfu
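
A worked example of the arithmetic with illustrative GPT-2-small-like numbers (every value below is an assumption for demonstration, not a measured Aeiva figure):

# Illustrative model shape: 124M params, 12 layers, 12 heads, head dim 64, block size 1024.
N = 124_000_000
L, H, Q, T = 12, 12, 64, 1024

flops_per_token = 6 * N + 12 * L * H * Q * T   # ~8.6e8
flops_per_fwdbwd = flops_per_token * T          # ~8.8e11
flops_per_iter = flops_per_fwdbwd * 5           # fwdbwd_per_iter = 5

dt = 1.0                                        # assumed seconds per iteration
mfu = flops_per_iter * (1.0 / dt) / 312e12      # fraction of A100 bf16 peak
print(f"MFU: {mfu:.1%}")                        # ~1.4% with these toy numbers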
get_num_params(non_embedding=True)

Return the number of parameters in the model. For non-embedding count (default), the position embeddings get subtracted. The token embeddings would too, except due to the parameter sharing these params are actually used as weights in the final layer, so we include them.

Source code in src/aeiva/trainer/pl_trainer.py
def get_num_params(self, non_embedding=True):
    """
    Return the number of parameters in the model.
    For non-embedding count (default), the position embeddings get subtracted.
    The token embeddings would too, except due to the parameter sharing these
    params are actually used as weights in the final layer, so we include them.
    """
    n_params = sum(p.numel() for p in self.model.parameters())
    if non_embedding:
        embedding_params = sum(p.numel() for m in self.model.modules() if isinstance(m, nn.Embedding) for p in m.parameters())
        n_params -= embedding_params
    return n_params

util

file_utils

from_json_or_yaml(filepath)

Load configuration from a JSON or YAML file based on the file extension.

Parameters:

- filepath (str): The path to the configuration file. (required)

Returns:

- dict: The configuration dictionary.

Raises:

- FileNotFoundError: If the file does not exist.
- ValueError: If the file extension is unsupported or if parsing fails.

Source code in src/aeiva/util/file_utils.py
def from_json_or_yaml(filepath: str) -> dict:
    """
    Load configuration from a JSON or YAML file based on the file extension.

    Args:
        filepath (str): The path to the configuration file.

    Returns:
        dict: The configuration dictionary.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file extension is unsupported or if parsing fails.
    """
    if not os.path.exists(filepath):
        logger.error(f"Configuration file not found at path: {filepath}")
        raise FileNotFoundError(f"Configuration file not found at path: {filepath}")

    _, ext = os.path.splitext(filepath)
    ext = ext.lower()

    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            if ext == '.json':
                config = json.load(f)
                logger.info(f"Loaded JSON configuration from {filepath}.")
                return config
            elif ext in ['.yaml', '.yml']:
                config = yaml.safe_load(f)
                logger.info(f"Loaded YAML configuration from {filepath}.")
                return config
            else:
                logger.error(f"Unsupported configuration file format: {ext}")
                raise ValueError(f"Unsupported configuration file format: {ext}")
    except (json.JSONDecodeError, yaml.YAMLError) as e:
        logger.error(f"Error parsing configuration file '{filepath}': {e}")
        raise ValueError(f"Error parsing configuration file '{filepath}': {e}")

os_utils

get_os_user()

Retrieve the current OS username.

Returns:

- str: The current OS user's name.

Source code in src/aeiva/util/os_utils.py
def get_os_user() -> str:
    """
    Retrieve the current OS username.

    Returns:
        str: The current OS user's name.
    """
    return getpass.getuser()

path_utils

get_package_root(package_name)

Obtain the root directory of a given package.

Parameters:

- package_name (str): The name of the package. (required)

Returns:

- str: The absolute path to the package root directory.

Source code in src/aeiva/util/path_utils.py
def get_package_root(package_name: str) -> str:
    """
    Obtain the root directory of a given package.

    Args:
        package_name (str): The name of the package.

    Returns:
        str: The absolute path to the package root directory.
    """
    spec = importlib.util.find_spec(package_name)
    if spec is None or spec.origin is None:
        raise ImportError(f"Cannot find package '{package_name}'")
    package_path = os.path.dirname(os.path.abspath(spec.origin))
    return package_path

get_user_home_path()

Retrieves the home directory of the current user across different platforms.

Supported Platforms
  • Windows
  • macOS
  • Linux
  • iOS (best-effort)
  • Android (best-effort)

Returns:

- Path: A Path object representing the user's home directory.

Raises:

- EnvironmentError: If the home directory cannot be determined.

Source code in src/aeiva/util/path_utils.py
def get_user_home_path() -> Path:
    """
    Retrieves the home directory of the current user across different platforms.

    Supported Platforms:
        - Windows
        - macOS
        - Linux
        - iOS (best-effort)
        - Android (best-effort)

    Returns:
        Path: A `Path` object representing the user's home directory.

    Raises:
        EnvironmentError: If the home directory cannot be determined.
    """
    system = platform.system()
    logger.info(f"Detected operating system: {system}")

    try:
        if system == "Windows":
            # Windows: Use USERPROFILE or combine HOMEDRIVE and HOMEPATH
            home = os.environ.get('USERPROFILE') or os.path.join(os.environ.get('HOMEDRIVE', ''), os.environ.get('HOMEPATH', ''))
            logger.debug(f"Windows home directory: {home}")
        elif system in ["Linux", "Darwin"]:  # Darwin is macOS
            # Unix-like systems: Use expanduser
            home = os.path.expanduser("~")
            logger.debug(f"Unix-like home directory: {home}")
        elif system == "Java":  # Potentially Android (e.g., running via Jython or similar)
            # Android typically uses /data/user/0/<package_name>/ or /sdcard/
            # However, accessing these paths may require specific permissions
            # Here, we attempt to use the HOME environment variable
            home = os.environ.get('HOME') or '/sdcard/'
            logger.debug(f"Android home directory (best-effort): {home}")
        elif system == "iOS":
            # iOS applications are sandboxed; home directory is typically the app's sandbox
            # Accessing it might require specific APIs or configurations
            # Here, we return the current working directory as a placeholder
            home = Path.cwd()
            logger.debug(f"iOS home directory (best-effort): {home}")
        else:
            # Fallback for unknown systems
            home = os.path.expanduser("~")
            logger.warning(f"Unknown system '{system}'. Falling back to expanduser: {home}")

        if home and os.path.isdir(home):
            return Path(home)
        else:
            raise EnvironmentError("Determined home directory does not exist or is not a directory.")

    except Exception as e:
        logger.error(f"Failed to determine the user's home directory: {e}")
        raise EnvironmentError("Cannot determine the user's home directory.") from e

snake_to_camel(snake_str)

Convert a snake_case string to CamelCase.

Parameters:

- snake_str (str): The snake_case string. (required)

Returns:

- str: The CamelCase string.

Source code in src/aeiva/util/path_utils.py
def snake_to_camel(snake_str: str) -> str:
    """
    Convert a snake_case string to CamelCase.

    Args:
        snake_str (str): The snake_case string.

    Returns:
        str: The CamelCase string.
    """
    components = snake_str.split('_')
    # Capitalize the first letter of each component
    return ''.join(x.title() for x in components)
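
This is the helper the toolkit uses to derive the ...Params/...Result class names from tool names; two illustrative calls:

from aeiva.util.path_utils import snake_to_camel  # import path assumed

assert snake_to_camel("view_file") == "ViewFile"
assert snake_to_camel("click_webpage_element") == "ClickWebpageElement"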

token_utils

pad_or_truncate_tokens(tokens, max_length, pad_token_id)

This function aims to pad or truncate tokens to max_length.

Parameters:

- tokens (list): The list of tokens. (required)
- max_length (int): The maximum length of tokens. (required)
- pad_token_id (int): The id of the pad token. (required)

Returns:

- tokens (list): The list of tokens after padding or truncating.

Source code in src/aeiva/util/token_utils.py
def pad_or_truncate_tokens(tokens, max_length, pad_token_id):
    """ This function aims to pad or truncate tokens to max_length.

    Args:
        tokens (list): the list of tokens.
        max_length (int): the max length of tokens.
        pad_token_id (int): the id of pad token.

    Returns:
        tokens (list): the list of tokens after padding or truncating.
    """
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
    elif len(tokens) < max_length:
        tokens = tokens + [pad_token_id] * (max_length - len(tokens))
    return tokens
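
Two illustrative calls, one padding and one truncating:

from aeiva.util.token_utils import pad_or_truncate_tokens  # import path assumed

assert pad_or_truncate_tokens([1, 2, 3], max_length=5, pad_token_id=0) == [1, 2, 3, 0, 0]
assert pad_or_truncate_tokens([1, 2, 3, 4, 5, 6], max_length=5, pad_token_id=0) == [1, 2, 3, 4, 5]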