Source code for topchef.models.service_list

from collections.abc import AsyncIterator as CollectionsAsyncIterator
from collections.abc import Awaitable
from typing import Union, Iterator, Sequence, AsyncIterator, Generator
from uuid import UUID

from sqlalchemy.orm import Session

from topchef.database.models import Service as DatabaseService
from topchef.json_type import JSON_TYPE as JSON
from topchef.models.interfaces.service_list import ServiceList as IServiceList
from topchef.models.service import Service


class ServiceList(IServiceList):
    """
    Implements a means of getting services from a relational DB back end.

    Services are addressed by their ``UUID``. Mutating operations
    (``__setitem__``, ``__delitem__`` and :meth:`new`) only stage changes
    on the session; nothing in this class commits or flushes it.
    """

    def __init__(self, session: Session) -> None:
        """
        :param session: The database session to use for getting services
        """
        self.session = session

    def __getitem__(self, service_id: UUID) -> Service:
        """
        :param service_id: The ID of the service to look up
        :return: The matching service, wrapped in the model-layer
            :class:`Service`
        :raises KeyError: If no service with that ID exists
        """
        db_model = self._get_db_model_by_id(self.session, service_id)
        return Service(db_model)

    def __setitem__(self, service_id: UUID, service: Service) -> None:
        """
        Copy the editable fields of ``service`` onto the stored service.

        Only ``is_service_available``, ``name`` and ``description`` are
        copied. The stored row keeps its own ID and all other columns.

        :param service_id: The ID of the stored service to update
        :param service: The service whose field values are copied
        :raises KeyError: If no service with ``service_id`` exists
        """
        db_model = self._get_db_model_by_id(self.session, service_id)
        db_model.is_service_available = service.is_service_available
        db_model.name = service.name
        db_model.description = service.description
        self.session.add(db_model)

    def __delitem__(self, service_id: UUID) -> None:
        """
        Delete the service with ``service_id`` together with its jobs.

        :param service_id: The ID of the service to delete
        :raises KeyError: If no service with that ID exists
        """
        db_model = self._get_db_model_by_id(self.session, service_id)
        # Jobs are deleted explicitly before their service. Presumably the
        # ORM relationship does not cascade deletes; TODO confirm against
        # the DB model definition.
        for job in db_model.jobs:
            self.session.delete(job)
        self.session.delete(db_model)

    def __contains__(
            self, service_or_service_id: Union[UUID, Service]
    ) -> bool:
        """
        :param service_or_service_id: Either a :class:`Service` or the ID
            of one
        :return: ``True`` if a service with the matching ID is stored
        """
        if isinstance(service_or_service_id, Service):
            return self._check_service_membership(service_or_service_id)
        return self._check_id_membership(service_or_service_id)

    def __len__(self) -> int:
        """
        :return: The number of services stored in the database
        """
        return self.session.query(DatabaseService).count()

    def __iter__(self) -> Iterator[Service]:
        """
        :return: An iterator over every stored service. All rows are
            fetched up front by the query; only the wrapping in
            :class:`Service` happens lazily.
        """
        return (
            Service(db_service)
            for db_service in self.session.query(DatabaseService).all()
        )

    def __aiter__(self) -> AsyncIterator[Service]:
        """
        Asynchronously iterate over the stored services.

        All rows are fetched synchronously when this method is called.
        Each item produced by ``async for`` is an awaitable that resolves
        to a :class:`Service`, i.e.
        ``async for item in services: service = await item``.
        """
        services = self.session.query(DatabaseService).all()  # type: list
        return self._AsynchronousServicesIterator(services)

    def new(
            self, name: str, description: str,
            registration_schema: JSON, result_schema: JSON
    ) -> Service:
        """
        Create a new service and stage it on the session.

        :param name: The name of the new service
        :param description: A description of the new service
        :param registration_schema: Forwarded to ``DatabaseService.new``
            as the registration schema
        :param result_schema: Forwarded to ``DatabaseService.new`` as the
            result schema
        :return: The newly-created service
        """
        service = DatabaseService.new(
            name, description, registration_schema, result_schema
        )
        self.session.add(service)
        return Service(service)

    @staticmethod
    def _get_db_model_by_id(
            session: Session, service_id: UUID
    ) -> DatabaseService:
        """
        :param session: The session to query
        :param service_id: The ID to look up
        :return: The database row whose ``id`` equals ``service_id``
        :raises KeyError: If no such row exists
        """
        db_model = session.query(
            DatabaseService
        ).filter_by(
            id=service_id
        ).first()
        if db_model is None:
            raise KeyError('A model with that ID does not exist')
        return db_model

    def _check_service_membership(self, service: Service) -> bool:
        """
        :return: ``True`` if a service with ``service.id`` is stored
        """
        # Membership is decided by ID alone, so reuse the ID check rather
        # than duplicating the query.
        return self._check_id_membership(service.id)

    def _check_id_membership(self, service_id: UUID) -> bool:
        """
        :return: ``True`` if a service with ``service_id`` is stored
        """
        number_of_matches = self.session.query(DatabaseService).filter_by(
            id=service_id
        ).count()
        return bool(number_of_matches)

    class _AsynchronousServicesIterator(CollectionsAsyncIterator):
        """
        Async iterator over an already-fetched list of database rows.

        Each step wraps the next row in a :class:`Service` and returns it
        inside an awaitable ``_AsyncServiceFuture``.
        """

        def __init__(self, services: Sequence[DatabaseService]) -> None:
            """
            :param services: The database rows to iterate over, in order
            """
            self.services = services
            self._last_served_index = 0

        def __len__(self) -> int:
            """
            :return: The total number of rows, including ones already
                served
            """
            return len(self.services)

        async def __anext__(self) -> Awaitable:
            """
            :return: An awaitable resolving to the next service
            :raises StopAsyncIteration: Once every row has been served
            """
            if self._last_served_index >= len(self.services):
                raise StopAsyncIteration()
            service = self._AsyncServiceFuture(
                Service(self.services[self._last_served_index])
            )
            self._last_served_index += 1
            return service

        class _AsyncServiceFuture(Awaitable):
            """
            An already-resolved awaitable wrapping a :class:`Service`.
            """

            def __init__(self, service: Service) -> None:
                """
                :param service: The value that awaiting this object returns
                """
                self.service = service

            def __await__(self) -> Generator[None, None, Service]:
                """
                :return: An iterator whose return value is the wrapped
                    service, as the ``__await__`` protocol requires
                """
                # ``__await__`` must return an iterator. Returning the
                # service directly made every ``await`` raise TypeError.
                # The empty ``yield from`` makes this a generator that
                # finishes immediately, without suspending, and uses the
                # service as its result.
                yield from ()
                return self.service