# Source code for carepoint.db.carepoint

# -*- coding: utf-8 -*-
# Copyright 2016-TODAY LasLabs Inc.
# License MIT (https://opensource.org/licenses/MIT).

import imp
import operator
import os
import urllib2

from contextlib import contextmanager

from sqlalchemy import bindparam, or_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import sessionmaker
from sqlalchemy import text


from .db import Db
from smb.SMBHandler import SMBHandler


Base = declarative_base()

# Give every declarative model a small dict-like interface: ``get(k, v)``
# and ``obj[k]`` read attribute ``k`` (returning ``v``, ``None`` by default,
# instead of raising when it is missing) and ``obj[k] = v`` sets it.
setattr(Base, 'get', lambda s, k, v=None: getattr(s, k, v))
setattr(Base, '__getitem__', lambda s, k, v=None: getattr(s, k, v))
setattr(Base, '__setitem__', lambda s, k, v: setattr(s, k, v))


# Module-level registries shared by all ``Carepoint`` instances:
# ``models`` mirrors every registered model class and ``dbs`` caches the
# ``Db`` object per database name.
# NOTE(review): ``env`` has no visible user here - ``Carepoint`` keeps its
# own per-instance ``env`` dict. Confirm before removing.
models = {}
env = {}
dbs = {}


class Carepoint(dict):
    """ Base CarePoint db connector object

    Besides wrapping database sessions, an instance behaves like a dict of
    model classes keyed by class name. Models are discovered by scanning
    :attr:`model_path` and are mirrored into the module-level ``models``
    registry, which is shared by all instances.
    """

    BASE = Base
    DEFAULT_DB = 'cph'
    # Default path to search for models - change with register_model_dir
    model_path = os.path.join(os.path.dirname(__file__), '..', 'models')
    # Domain operators accepted in ``search`` filters, mapped to the Python
    # comparison that builds the SQLAlchemy criterion.
    FILTERS = {
        '>=': operator.ge,
        '>': operator.gt,
        '<=': operator.le,
        '<': operator.lt,
        '=': operator.eq,
        '==': operator.eq,
    }

    def __init__(
        self, server, user, passwd, smb_user=None, smb_passwd=None,
        db_args=None, **engine_args
    ):
        """ It initializes new Carepoint object

        Args:
            server (str): IP or Hostname to database
            user (str): Username for database
            passwd (str): Password for database
            smb_user (str): Username to use for SMB connection, ``None`` to
                use the database user
            smb_passwd (str): Password to use for the SMB connection. Only
                used when ``smb_user`` is given; when ``smb_user`` is
                ``None`` the database credentials are used for SMB.
            db_args (dict): Dictionary of arguments to send during initial
                db creation
            **engine_args (mixed): Kwargs to pass to ``create_engine``
        """
        super(Carepoint, self).__init__()
        self.env = {}
        # ``Db`` objects live in the module-level ``dbs`` cache, so the
        # first instance's ``db_params`` win; later instances reuse that Db
        # unless ``_init_env(clear=True)`` is called.
        self.dbs = dbs
        self.iter_refresh = False
        params = {
            'user': user,
            'passwd': passwd,
            'server': server,
            'db': 'cph',
        }
        if db_args is not None:
            params.update(db_args)
        if engine_args:
            params.update(engine_args)
        if smb_user is None:
            self.smb_creds = {'user': user, 'passwd': passwd}
        else:
            self.smb_creds = {'user': smb_user, 'passwd': smb_passwd}
        self.db_params = params
        self._init_env(False)

    def _init_env(self, clear=False):
        """ It initializes the global db and environments

        Params:
            clear: (bool) True to clear the global db cache, and this
                instance's session factories, before initializing
        """
        if clear:
            self.dbs.clear()
            # The cached sessionmakers are bound to the old Db objects;
            # drop them too so they are rebuilt against the fresh Db.
            self.env.clear()
        # @TODO: Lazy load, once other dbs needed
        if not self.dbs.get('cph'):
            self.dbs['cph'] = Db(**self.db_params)
        if not self.env.get('cph'):
            # NOTE(review): with expire_on_commit=True, the commit in
            # ``_get_session`` expires every instance and ``close()`` then
            # detaches it. Per the SQLAlchemy docs, loading an expired
            # attribute of a detached instance raises DetachedInstanceError.
            # This affects records returned by read/create/update. Confirm
            # with callers whether ``expire_on_commit=False`` is wanted.
            self.env['cph'] = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.dbs['cph'],
                expire_on_commit=True,
            )

    def _get_model_session(self, model_obj):
        """ Return a session context manager for ``model_obj``'s database

        Args:
            model_obj: Model class whose ``__dbname__`` selects the database
        Returns:
            Context manager produced by :meth:`_get_session`
        """
        return self._get_session(model_obj.__dbname__)

    @contextmanager
    def _get_session(self, db_name):
        """ Yield a session; commit on success, roll back on error

        The session is always closed on exit.

        Args:
            db_name (str): Key of the session factory in ``self.env``
        """
        session = self.env[db_name]()
        try:
            yield session
            session.commit()
        except BaseException:
            # Roll back on any failure, including a failing commit, then
            # propagate it unchanged.
            session.rollback()
            raise
        finally:
            session.close()

    @property
    def _smb_prefix(self):
        """ Return URI prefix for SMB share """
        # NOTE(review): credentials are interpolated unescaped, so a ':'
        # or '@' in them would corrupt the URI - confirm how SMBHandler
        # parses/unquotes userinfo before escaping here.
        return 'smb://{user}:{passwd}@'.format(**self.smb_creds)

    def get_file(self, path):
        """ Return a file-like object for the SMB path

        Args:
            path: :type:`str` SMB path to fetch
        Returns:
            :type:`file` File interface object representing remote resource.
            The caller is responsible for closing it.
        """
        opener = urllib2.build_opener(SMBHandler)
        return opener.open('%s%s' % (self._smb_prefix, path))

    def send_file(self, path, file_obj):
        """ Send a file-like object to the SMB path

        Args:
            path: :type:`str` SMB path to fetch
            file_obj: :type:`file` File interface object to send to server
        Returns:
            :type:`bool` Success
        """
        # ``OpenerDirector`` is not a context manager, so it cannot be used
        # in a ``with`` statement; build it plainly and close the response
        # handle ourselves once the upload has been sent.
        opener = urllib2.build_opener(SMBHandler)
        handle = opener.open(
            '%s%s' % (self._smb_prefix, path), data=file_obj,
        )
        handle.close()
        return True

    def _create_criterion(self, model_obj, col_name, operator, query):
        """ Create a SQLAlchemy criterion from filter parts

        Args:
            model_obj: :class:`sqlalchemy.Table` Table class to search
            col_name: :type:`str` Name of column to query
            operator: :type:`str` Domain operator to use in query (a key of
                :attr:`FILTERS`; note it shadows the ``operator`` module
                inside this method)
            query: :type:`str` Text to search for
        Returns:
            SQLAlchemy criterion representing a single WHERE clause
        Raises:
            AttributeError: When col_name does not exist in the model_obj
            KeyError: When the query operator is not in :attr:`FILTERS`
        """
        col_obj = getattr(model_obj, col_name)
        operator_obj = self.FILTERS[operator]
        return operator_obj(col_obj, query)

    def _unwrap_filters(self, model_obj, filters=None):
        """ Unwrap a dictionary of filters into something usable by SQLAlchemy

        Each value in ``filters`` may be:

        * a dict of ``{operator: value}`` - one criterion per pair (ANDed),
        * a list/tuple of values - equality criteria joined with ``OR``,
        * anything else - a single equality criterion.

        :param model_obj: Table class to search
        :type model_obj: :class:`sqlalchemy.Table`
        :param filters: Filters, keyed by col name
        :type filters: dict
        :rtype: list
        """
        if filters is None:
            filters = {}
        new_filters = []
        for col_name, col_filter in filters.items():
            if isinstance(col_filter, dict):
                new_filters.extend(
                    self._create_criterion(model_obj, col_name, op, value)
                    for op, value in col_filter.items()
                )
            elif isinstance(col_filter, (list, tuple)):
                new_filters.append(or_(*[
                    self._create_criterion(model_obj, col_name, '==', value)
                    for value in col_filter
                ]))
            else:
                new_filters.append(self._create_criterion(
                    model_obj, col_name, '==', col_filter,
                ))
        return new_filters

    def _create_entities(self, model_obj, cols):
        """ Return list of entities matching cols

        Names that are not attributes of ``model_obj`` are skipped.

        :param model_obj: Table class to search
        :type model_obj: :class:`sqlalchemy.Table`
        :param cols: List of col names
        :type cols: list
        :rtype: :type:`list` of :class:`sqlalchemy.Column`
        """
        out = []
        for col in cols:
            try:
                out.append(getattr(model_obj, col))
            except AttributeError:
                pass
        return out

    def _pk_criteria(self, model_obj, record_id):
        """ Build equality criteria matching ``record_id`` to the primary key

        Args:
            model_obj: Mapped model class
            record_id: Scalar id, or tuple/list for a composite key (values
                in primary key column order)
        Returns:
            list of SQLAlchemy criteria, one per primary key column
        Raises:
            ValueError: When the number of values does not match the number
                of primary key columns
        """
        pk_cols = inspect(model_obj).primary_key
        if isinstance(record_id, (tuple, list)):
            pk_vals = tuple(record_id)
        else:
            pk_vals = (record_id, )
        if len(pk_vals) != len(pk_cols):
            raise ValueError(
                'Expected %d primary key value(s), got %d' % (
                    len(pk_cols), len(pk_vals),
                )
            )
        return [col == val for col, val in zip(pk_cols, pk_vals)]

    def read(self, model_obj, record_id, with_entities=None):
        """ Get record by id and return the object

        :param model_obj: Table class to search
        :type model_obj: :class:`sqlalchemy.Table`
        :param record_id: Id of record to fetch; a tuple/list for a
            composite primary key
        :param with_entities: List of col names to select, None for all
        :type with_entities: list or None
        :return: The model instance, or ``None`` when no record matches.
            When ``with_entities`` names at least one valid column, a row
            containing only those columns (or ``None``) is returned instead.
        :raises ValueError: When selecting columns and ``record_id`` does
            not supply one value per primary key column
        """
        with self._get_model_session(model_obj) as session:
            query = session.query(model_obj)
            entities = self._create_entities(model_obj, with_entities or [])
            if not entities:
                return query.get(record_id)
            # ``Query.get`` only works on a query against the whole mapped
            # class, so filter on the primary key to project columns.
            return query.filter(
                *self._pk_criteria(model_obj, record_id)
            ).with_entities(*entities).first()

    def search(self, model_obj, filters=None, with_entities=None):
        """ Search table by filters and return records

        :param model_obj: Table class to search
        :type model_obj: :class:`sqlalchemy.schema.Table`
        :param filters: Filters to apply to search (see ``_unwrap_filters``)
        :type filters: dict or None
        :param with_entities: List of col names to select, None for all
        :type with_entities: list or None
        :return: Unexecuted query for the matching records
        :rtype: :class:`sqlalchemy.orm.query.Query`
        """
        with self._get_model_session(model_obj) as session:
            criteria = self._unwrap_filters(model_obj, filters)
            query = session.query(model_obj).filter(*criteria)
            entities = self._create_entities(model_obj, with_entities or [])
            if entities:
                # ``with_entities`` returns a new Query; keep it.
                query = query.with_entities(*entities)
            return query

    def create(self, model_obj, vals):
        """ Wrapper to create a record in Carepoint

        The record is committed when the session context exits.

        :param model_obj: Table class to create with
        :type model_obj: :class:`sqlalchemy.schema.Table`
        :param vals: Data to create record with
        :type vals: dict
        :rtype: :class:`sqlalchemy.ext.declarative.Declarative`
        """
        with self._get_model_session(model_obj) as session:
            record = model_obj(**vals)
            session.add(record)
            return record

    def update(self, model_obj, record_id, vals):
        """ Wrapper to update a record in Carepoint

        :param model_obj: Table class to update
        :type model_obj: :class:`sqlalchemy.schema.Table`
        :param record_id: Id of record to manipulate
        :type record_id: int
        :param vals: Data to update the record with
        :type vals: dict
        :return: The updated record, or ``None`` if no record has that id
        :rtype: :class:`sqlalchemy.ext.declarative.Declarative`
        """
        with self._get_model_session(model_obj) as session:
            # Load through *this* session so the changes are flushed by its
            # commit; a record from another session would not be tracked.
            record = session.query(model_obj).get(record_id)
            if record is None:
                return None
            for key, val in vals.items():
                setattr(record, key, val)
            return record

    def delete(self, model_obj, record_id):
        """ Wrapper to delete a record in Carepoint

        :param model_obj: Table class to update
        :type model_obj: :class:`sqlalchemy.schema.Table`
        :param record_id: Id of record to manipulate
        :type record_id: int
        :return: Whether the record was found, and deleted
        :rtype: bool
        """
        with self._get_model_session(model_obj) as session:
            record = session.query(model_obj).get(record_id)
            if record is None:
                return False
            session.delete(record)
            return True

    def get_pks(self, model_obj):
        """ Return the Primary keys in the model

        :param model_obj: Table class to update
        :type model_obj: :class:`sqlalchemy.schema.Table`
        :return: Tuple of primary key name strings
        :rtype: tuple
        """
        return tuple(k.name for k in inspect(model_obj).primary_key)

    def get_next_sequence(self, sequence_name, db_name='cph'):
        """ It generates and returns the next int in sequence

        Params:
            sequence_name: ``str`` Name of the sequence in Carepoint DB
            db_name: ``str`` Name of DB containing sequence stored proc
        Return:
            Integer to use as pk
        """
        with self._get_session(db_name) as session:
            res = session.connection().execute(
                text(
                    "SET NOCOUNT ON;"
                    "DECLARE @out int = 0;"
                    "EXEC CsGenerateIntId :seq_name, @out output;"
                    "SELECT @out;"
                    "SET NOCOUNT OFF;",
                    bindparams=[bindparam('seq_name')],
                ),
                seq_name=sequence_name,
            )
            id_int = res.fetchall()[0][0]
            return id_int

    def __getattr__(self, key):
        """ Fall back to model lookup when normal attribute access fails

        This makes ``carepoint.ModelName`` equivalent to
        ``carepoint['ModelName']``.

        Raises:
            AttributeError: When ``key`` is neither an attribute nor a
                registered or discoverable model
        """
        try:
            return super(Carepoint, self).__getattr__(key)
        except AttributeError:
            try:
                return self.__getitem__(key)
            except KeyError:
                raise AttributeError(key)

    def __setitem__(self, key, val, __global=False, *args, **kwargs):
        """ Re-implement __setitem__ to allow for global model sync

        A true ``__global`` marks a call that copies a value *from* the
        global ``models`` registry, so it is not written back there.
        """
        super(Carepoint, self).__setitem__(key, val, *args, **kwargs)
        if not __global:
            models[key] = val

    def __getitem__(self, key, retry=True, default=False):
        """ Re-implement __getitem__ to scan for models if key missing

        Every lookup first copies the global ``models`` registry into this
        instance. On a miss:

        * ``default`` is returned if it is not ``False``;
        * otherwise, with ``retry``, ``find_models`` runs once and the
          lookup is repeated;
        * otherwise ``KeyError`` is raised.
        """
        for k, v in models.items():
            self.__setitem__(k, v, True)
        try:
            return super(Carepoint, self).__getitem__(key)
        except KeyError:
            if default is not False:
                return default
            elif retry:
                self.find_models()
                return self.__getitem__(key, False)
            else:
                raise KeyError(
                    'Plugin "%s" not found in model_dir "%s"' % (
                        key, self.model_path
                    )
                )

    def set_iter_refresh(self, refresh=True):
        """ Toggle flag to search for new models before iteration

        :param refresh: Whether to refresh before iteration
        :type refresh: bool
        """
        self.iter_refresh = refresh

    def __refresh_models__(self):
        """ Run ``find_models`` if iteration refresh is enabled """
        if self.iter_refresh:
            self.find_models()

    def __iter__(self):
        """ Reimplement __iter__ to allow for optional model refresh """
        self.__refresh_models__()
        return super(Carepoint, self).__iter__()

    def values(self):
        """ Reimplement values to allow for optional model refresh """
        self.__refresh_models__()
        return super(Carepoint, self).values()

    def keys(self):
        """ Reimplement keys to allow for optional model refresh """
        self.__refresh_models__()
        return super(Carepoint, self).keys()

    def items(self):
        """ Reimplement items to allow for optional model refresh """
        self.__refresh_models__()
        return super(Carepoint, self).items()

    def itervalues(self):
        """ Reimplement itervalues to allow for optional model refresh """
        self.__refresh_models__()
        return super(Carepoint, self).itervalues()

    def iterkeys(self):
        """ Reimplement iterkeys to allow for optional model refresh """
        self.__refresh_models__()
        return super(Carepoint, self).iterkeys()

    def iteritems(self):
        """ Reimplement iteritems to allow for optional model refresh """
        self.__refresh_models__()
        return super(Carepoint, self).iteritems()

    def register_model(self, model_obj):
        """ Registration logic + append to models struct

        :param model_obj: Model object to register
        :type model_obj: :class:`sqlalchemy.ext.declarative.Declarative`
        """
        self[model_obj.__name__] = model_obj

    def register_model_dir(self, model_path):
        """ This function sets the model path to be searched

        :param model_path: Path of models
        :type model_path: str
        :raises EnvironmentError: When ``model_path`` is not a directory
        """
        if os.path.isdir(model_path):
            self.model_path = model_path
        else:
            raise EnvironmentError('%s is not a directory' % model_path)

    @staticmethod
    def _load_model_module(module, dir_name):
        """ Import ``module`` from ``dir_name`` and return the module """
        f, filename, desc = imp.find_module(module, [dir_name])
        try:
            return imp.load_module(module, f, filename, desc)
        finally:
            # ``imp`` leaves closing the file to the caller, even on error.
            if f is not None:
                f.close()

    def _register_module_models(self, mod_obj):
        """ Register every attribute of ``mod_obj`` with a __tablename__ """
        for attr_name in dir(mod_obj):
            if attr_name.startswith('__'):
                continue
            model_obj = getattr(mod_obj, attr_name)
            if not hasattr(model_obj, '__tablename__'):
                continue
            if not hasattr(model_obj, '__dbname__'):
                model_obj.__dbname__ = self.DEFAULT_DB
            self.register_model(model_obj)

    def find_models(self, model_path=None):
        """ Traverse registered model directory and import non-loaded modules

        Every ``*.py`` file (except ``__init__.py``) under ``model_path`` is
        imported, and its attributes with a ``__tablename__`` are
        registered via :meth:`register_model`.

        Args:
            model_path (str): Directory to scan, ``None`` to use
                :attr:`model_path`
        Raises:
            EnvironmentError: When ``model_path`` is not a directory
        """
        if model_path is None:
            model_path = self.model_path
        if model_path is not None and not os.path.isdir(model_path):
            raise EnvironmentError('%s is not a directory' % model_path)
        for dir_name, _subdirs, files in os.walk(model_path):
            # os.walk yields full paths, so test the last path component to
            # skip dunder directories such as ``__pycache__``.
            if os.path.basename(dir_name).startswith('__'):
                continue
            dir_name = os.path.abspath(dir_name)
            for file_ in files:
                if not file_.endswith('.py') or file_ == '__init__.py':
                    continue
                module = file_[:-3]
                # NOTE(review): this consults *this* module's globals, so a
                # model file named like a name imported here (e.g. ``text``
                # or ``inspect``) resolves to that object instead - confirm
                # whether ``sys.modules`` was intended.
                mod_obj = globals().get(module)
                if mod_obj is None:
                    mod_obj = self._load_model_module(module, dir_name)
                self._register_module_models(mod_obj)