from typing import Any, Dict, Union import ray from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage class KVStorageActor: def __init__(self, backend: str, working_dir: str, namespace: str): if backend == "json_kv": from graphgen.models import JsonKVStorage self.kv = JsonKVStorage(working_dir, namespace) elif backend == "rocksdb": from graphgen.models import RocksDBKVStorage self.kv = RocksDBKVStorage(working_dir, namespace) else: raise ValueError(f"Unknown KV backend: {backend}") def data(self) -> Dict[str, Dict]: return self.kv.data def all_keys(self) -> list[str]: return self.kv.all_keys() def index_done_callback(self): return self.kv.index_done_callback() def get_by_id(self, id: str) -> Dict: return self.kv.get_by_id(id) def get_by_ids(self, ids: list[str], fields=None) -> list: return self.kv.get_by_ids(ids, fields) def get_all(self) -> Dict[str, Dict]: return self.kv.get_all() def filter_keys(self, data: list[str]) -> set[str]: return self.kv.filter_keys(data) def upsert(self, data: dict) -> dict: return self.kv.upsert(data) def drop(self): return self.kv.drop() def reload(self): return self.kv.reload() def ready(self) -> bool: return True class GraphStorageActor: def __init__(self, backend: str, working_dir: str, namespace: str): if backend == "networkx": from graphgen.models import NetworkXStorage self.graph = NetworkXStorage(working_dir, namespace) elif backend == "kuzu": from graphgen.models import KuzuStorage self.graph = KuzuStorage(working_dir, namespace) else: raise ValueError(f"Unknown Graph backend: {backend}") def index_done_callback(self): return self.graph.index_done_callback() def has_node(self, node_id: str) -> bool: return self.graph.has_node(node_id) def has_edge(self, source_node_id: str, target_node_id: str): return self.graph.has_edge(source_node_id, target_node_id) def node_degree(self, node_id: str) -> int: return self.graph.node_degree(node_id) def edge_degree(self, src_id: str, tgt_id: str) -> int: return self.graph.edge_degree(src_id, tgt_id) def get_node(self, node_id: str) -> Any: return self.graph.get_node(node_id) def update_node(self, node_id: str, node_data: dict[str, str]): return self.graph.update_node(node_id, node_data) def get_all_nodes(self) -> Any: return self.graph.get_all_nodes() def get_edge(self, source_node_id: str, target_node_id: str): return self.graph.get_edge(source_node_id, target_node_id) def update_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): return self.graph.update_edge(source_node_id, target_node_id, edge_data) def get_all_edges(self) -> Any: return self.graph.get_all_edges() def get_node_edges(self, source_node_id: str) -> Any: return self.graph.get_node_edges(source_node_id) def upsert_node(self, node_id: str, node_data: dict[str, str]): return self.graph.upsert_node(node_id, node_data) def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): return self.graph.upsert_edge(source_node_id, target_node_id, edge_data) def delete_node(self, node_id: str): return self.graph.delete_node(node_id) def reload(self): return self.graph.reload() def ready(self) -> bool: return True class RemoteKVStorageProxy(BaseKVStorage): def __init__(self, actor_handle: ray.actor.ActorHandle): super().__init__() self.actor = actor_handle def data(self) -> Dict[str, Any]: return ray.get(self.actor.data.remote()) def all_keys(self) -> list[str]: return ray.get(self.actor.all_keys.remote()) def index_done_callback(self): return ray.get(self.actor.index_done_callback.remote()) def get_by_id(self, id: str) -> Union[Any, None]: return ray.get(self.actor.get_by_id.remote(id)) def get_by_ids(self, ids: list[str], fields=None) -> list[Any]: return ray.get(self.actor.get_by_ids.remote(ids, fields)) def get_all(self) -> Dict[str, Any]: return ray.get(self.actor.get_all.remote()) def filter_keys(self, data: list[str]) -> set[str]: return ray.get(self.actor.filter_keys.remote(data)) def upsert(self, data: Dict[str, Any]): return ray.get(self.actor.upsert.remote(data)) def drop(self): return ray.get(self.actor.drop.remote()) def reload(self): return ray.get(self.actor.reload.remote()) class RemoteGraphStorageProxy(BaseGraphStorage): def __init__(self, actor_handle: ray.actor.ActorHandle): super().__init__() self.actor = actor_handle def index_done_callback(self): return ray.get(self.actor.index_done_callback.remote()) def has_node(self, node_id: str) -> bool: return ray.get(self.actor.has_node.remote(node_id)) def has_edge(self, source_node_id: str, target_node_id: str): return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id)) def node_degree(self, node_id: str) -> int: return ray.get(self.actor.node_degree.remote(node_id)) def edge_degree(self, src_id: str, tgt_id: str) -> int: return ray.get(self.actor.edge_degree.remote(src_id, tgt_id)) def get_node(self, node_id: str) -> Any: return ray.get(self.actor.get_node.remote(node_id)) def update_node(self, node_id: str, node_data: dict[str, str]): return ray.get(self.actor.update_node.remote(node_id, node_data)) def get_all_nodes(self) -> Any: return ray.get(self.actor.get_all_nodes.remote()) def get_edge(self, source_node_id: str, target_node_id: str): return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id)) def update_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): return ray.get( self.actor.update_edge.remote(source_node_id, target_node_id, edge_data) ) def get_all_edges(self) -> Any: return ray.get(self.actor.get_all_edges.remote()) def get_node_edges(self, source_node_id: str) -> Any: return ray.get(self.actor.get_node_edges.remote(source_node_id)) def upsert_node(self, node_id: str, node_data: dict[str, str]): return ray.get(self.actor.upsert_node.remote(node_id, node_data)) def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): return ray.get( self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data) ) def delete_node(self, node_id: str): return ray.get(self.actor.delete_node.remote(node_id)) def reload(self): return ray.get(self.actor.reload.remote()) class StorageFactory: """ Factory class to create storage instances based on backend. """ @staticmethod def create_storage(backend: str, working_dir: str, namespace: str): if backend in ["json_kv", "rocksdb"]: actor_name = f"Actor_KV_{namespace}" actor_class = KVStorageActor proxy_class = RemoteKVStorageProxy elif backend in ["networkx", "kuzu"]: actor_name = f"Actor_Graph_{namespace}" actor_class = GraphStorageActor proxy_class = RemoteGraphStorageProxy else: raise ValueError(f"Unknown storage backend: {backend}") try: actor_handle = ray.get_actor(actor_name) except ValueError: actor_handle = ray.remote(actor_class).options( name=actor_name, get_if_exists=True, ).remote(backend, working_dir, namespace) ray.get(actor_handle.ready.remote()) return proxy_class(actor_handle) def init_storage(backend: str, working_dir: str, namespace: str): return StorageFactory.create_storage(backend, working_dir, namespace)