"""We provide some infrastructure to build extensible database models."""
import sqlite3
try: # pragma: no cover
import simplejson as json
except ImportError: # pragma: no cover
import json
from six import string_types, text_type, PY2
from pytz import UTC
from sqlalchemy import (
Column, Integer, Float, String, Boolean, DateTime, func, event)
from sqlalchemy.exc import DisconnectionError
from sqlalchemy.pool import Pool
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import scoped_session, sessionmaker, deferred, undefer
from sqlalchemy.types import TypeDecorator, VARCHAR
from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound
from sqlalchemy.orm.query import Query
from sqlalchemy.inspection import inspect
from sqlalchemy.dialects.postgresql import TSVECTOR
from zope.sqlalchemy import ZopeTransactionExtension
from clldutils.misc import NO_DEFAULT, UnicodeMixin
from clldutils import jsonlib
from clld.db.versioned import versioned_session
@event.listens_for(Pool, "checkout")
def ping_connection(dbapi_connection, connection_record, connection_proxy):
"""Event listener to handle disconnects.
Implements
`pessimistic disconnect handling <http://docs.sqlalchemy.org/en/rel_0_9/core/\
pooling.html#disconnect-handling-pessimistic>`_.
.. note::
Our implementation is mildly dialect specific, but works for sqlite and
PostgreSQL. For oracle, the 'ping' query should read *SELECT 1 FROM DUAL* or
similar.
"""
cursor = dbapi_connection.cursor()
try:
cursor.execute("SELECT 1")
if not isinstance(dbapi_connection, sqlite3.Connection): # pragma: no cover
cursor.execute("SET default_text_search_config TO 'english'")
except: # pragma: no cover
# dispose the whole pool instead of invalidating one at a time
connection_proxy._pool.dispose()
# raise DisconnectionError - pool will try
# connecting again up to three times before raising.
raise DisconnectionError()
cursor.close()
class ActiveOnlyQuery(Query): # pragma: no cover
"""A pre-filtering query.
Implements a
`pre-filtering query <http://www.sqlalchemy.org/trac/wiki/UsageRecipes/\
PreFilteredQuery>`_ that filters on the :py:attr:`clld.db.meta._Base.active` flag.
"""
def get(self, ident):
# override get() so that the flag is always checked in the
# DB as opposed to pulling from the identity map.
return Query.get(self.populate_existing(), ident)
def __iter__(self):
return Query.__iter__(self.private())
def from_self(self, *ent):
# override from_self() to automatically apply
# the criterion too. this works with count() and
# others.
return Query.from_self(self.private(), *ent)
def private(self):
mzero = self._mapper_zero()
if mzero is not None:
crit = mzero.class_.active == True
return self.enable_assertions(False).filter(crit)
else:
return self
DBSession = scoped_session(sessionmaker(extension=ZopeTransactionExtension()))
ActiveOnlyDBSession = scoped_session(sessionmaker(
extension=ZopeTransactionExtension(), query_cls=ActiveOnlyQuery))
VersionedDBSession = scoped_session(versioned_session(
sessionmaker(autoflush=False, extension=ZopeTransactionExtension())))
class JSONEncodedDict(TypeDecorator):
"""Represents an immutable structure as a json-encoded string.
Loads/serializes an empty dict for any empty value.
"""
impl = VARCHAR
def process_bind_param(self, value, dialect):
if not value:
value = {}
return json.dumps(value)
def process_result_value(self, value, dialect):
if not value:
return {}
return json.loads(value)
def _solr_timestamp(dt):
if not dt:
return
try:
dt = dt.astimezone(UTC)
except ValueError:
pass
return dt.isoformat().split('+')[0] + 'Z'
class CsvMixin(object):
"""Mixin providing methods to control (de-)serialization of an object as csv row."""
#: base name of the csv file
__csv_name__ = None
@classmethod
def csv_head(cls):
"""return List of column names."""
exclude = {'active', 'version', 'created', 'updated', 'polymorphic_type'}
cols = sorted(
col.key for om in inspect(cls).iterate_to_root()
for col in om.local_table.c
if (col.key not in exclude
and col.type.__class__ not in [TSVECTOR]
and not exclude.add(col.key)))
return cols
def value_to_csv(self, attr, ctx=None, req=None):
"""Convert one value to a representation suitable for csv writer.
:param attr: Name of the attribute from which to convert the value.
:return: Object suitable for serialization with csv writer.
"""
rel = None
if attr.endswith('__ids') or attr.endswith('__id'):
attr = attr.split('__')
rel = attr[-1]
attr = '__'.join(attr[:-1])
prop = getattr(self, attr, '')
if attr == 'jsondata':
prop = json.dumps(prop)
if PY2:
prop = prop.decode('utf8')
if rel == 'id' and hasattr(prop, 'id'):
return prop.id
elif rel == 'ids':
return ','.join('%s' % o.id for o in prop)
return prop
def to_csv(self, ctx=None, req=None, cols=None):
"""return list of values to be passed to csv.writer.writerow."""
return [self.value_to_csv(attr, ctx, req) for attr in cols or self.csv_head()]
@classmethod
def value_from_csv(cls, attr, value):
if not value:
return None
col = getattr(cls, attr)
if hasattr(col, 'property') and hasattr(col.property, 'columns'):
if isinstance(col.property.columns[0].type, Integer):
return int(value)
if isinstance(col.property.columns[0].type, Float):
if isinstance(value, string_types):
value = value.replace(',', '.')
return float(value)
return value
@classmethod
def from_csv(cls, row, data=None, cols=None):
obj = cls()
cols = cols or obj.csv_head()
for i, k in enumerate(cols):
if not (k.endswith('__id') or k.endswith('__ids')) and hasattr(obj, k):
setattr(obj, k, cls.value_from_csv(k, row[i]) or None)
return obj
@classmethod
def csv_query(cls, session):
query = session.query(cls).filter_by(active=True)
return query.order_by(getattr(cls, 'id', getattr(cls, 'pk', None)))
[docs]class Base(UnicodeMixin, CsvMixin, declarative_base()):
"""The declarative base for all our models."""
__abstract__ = True
@declared_attr
def __tablename__(cls):
"""We derive the table name from the model class name.
This should be safe,
because we don't want to have model classes with the same name either.
Care has to be taken, though, to prevent collisions with the names of tables
which are automatically created (history tables for example).
"""
return cls.__name__.lower()
#: All our models have an integer primary key which has nothing to do with
#: the kind of data stored in a table. 'Natural' candidates for primary keys
#: should be marked with unique constraints instead. This adds flexibility
#: when it comes to database changes.
pk = Column(Integer, primary_key=True, doc='primary key')
#: To allow for timestamp-based versioning - as opposed or in addition to the version
#: number approach implemented in :py:class:`clld.db.meta.Versioned` - we store
#: a timestamp for creation or an object.
@declared_attr
def created(cls):
return deferred(Column(DateTime(timezone=True), default=func.now()))
#: Timestamp for latest update of an object.
@declared_attr
def updated(cls):
return deferred(Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()))
#: The active flag is meant as an easy way to mark records as obsolete or inactive,
#: without actually deleting them. A custom Query class could then be used which
#: filters out inactive records.
@declared_attr
def active(cls):
return deferred(Column(Boolean, default=True))
#: To allow storage of arbitrary key,value pairs with typed values, each model
#: provides a column to store JSON encoded dicts.
jsondata = Column(JSONEncodedDict)
def __init__(self, jsondata=None, **kwargs):
kwargs['jsondata'] = jsondata or {}
super(Base, self).__init__(**kwargs)
[docs] def update_jsondata(self, **kw):
"""Convenience function.
Since we use the simple
`JSON encoded dict recipe <http://docs.sqlalchemy.org/en/rel_0_9/core/types.html\
#marshal-json-strings>`_
without mutation tracking, we provide a convenience method to update
"""
d = self.jsondata.copy()
d.update(kw)
self.jsondata = d
@property
def jsondatadict(self):
"""Deprecated convenience function.
Use jsondata directly instead, which is guaranteed to be a dictionary.
"""
return self.jsondata or {}
@property
def replacement_id(self):
"""Used to allow automatically redirecting to a 'better' version of a resource."""
if not self.active:
return self.jsondata.get('__replacement_id__')
[docs] @classmethod
def get(cls, value, key=None, default=NO_DEFAULT, session=None):
"""Convenience method to query a model where exactly one result is expected.
e.g. to retrieve an instance by primary key or id.
:param value: The value used in the filter expression of the query.
:param str key: The key or attribute name to be used in the filter expression. If\
None is passed, defaults to *pk* if value is ``int`` otherwise to *id*.
"""
session = session or DBSession
if key is None:
key = 'pk' if isinstance(value, int) else 'id'
try:
return session.query(cls)\
.options(undefer('updated')).filter_by(**{key: value}).one()
except (NoResultFound, MultipleResultsFound):
if default is NO_DEFAULT:
raise
return default
@classmethod
def first(cls):
"""More convenience."""
return DBSession.query(cls).order_by(cls.pk).first()
[docs] def history(self):
"""return result proxy to iterate over previous versions of a record."""
model = self.__class__
if not hasattr(model, '__history_mapper__'):
return [] # pragma: no cover
history_class = model.__history_mapper__.class_
return DBSession.query(history_class).filter(history_class.pk == self.pk)\
.order_by(history_class.version.desc())
def __json__(self, req):
"""Custom JSON serialization of an object.
:param req: pyramid Request object.
:return: ``dict`` suitable for serialization as JSON.
"""
exclude = {'active', 'version', 'created', 'updated', 'polymorphic_type'}
cols = [
col.key for om in inspect(self).mapper.iterate_to_root()
for col in om.local_table.c
if col.key not in exclude and not exclude.add(col.key)]
return {col: jsonlib.format(getattr(self, col)) for col in cols}
def __solr__(self, req):
"""Custom solr document representing the object.
:param req: pyramid Request object.
:return: ``dict`` suitable as JSON encoded \
`Solr <https://lucene.apache.org/solr/>`_ document.
.. note::
The document returned by this method does only make sense when used with an
appropriate Solr schema. In particular we rely on name conventions for
`dynamic fields <https://cwiki.apache.org/confluence/display/solr/\
Dynamic+Fields>`_.
"""
cls = inspect(self).class_
if not is_base(cls):
for base in cls.__bases__:
if is_base(base):
cls = base
break
res = dict(
id=getattr(self, 'id', str(self.pk)),
url=req.resource_url(self) if req else None,
dataset=req.dataset.id if req else None,
rscname=cls.__name__,
name=getattr(self, 'name', '%s %s' % (self.__class__.__name__, self.pk)),
active=self.active,
)
for attr in ['updated', 'created']:
value = _solr_timestamp(getattr(self, attr))
if value:
res[attr] = value
suffix_map = [(text_type, '_t'), (bool, '_b'), (int, '_i'), (float, '_f')]
for om in inspect(self).mapper.iterate_to_root():
for col in om.local_table.c:
if col.key not in res and col.key != 'polymorphic_type':
value = getattr(self, col.key)
for type_, suffix in suffix_map:
if isinstance(value, type_):
res[col.key + suffix] = value
break
return res
def __unicode__(self):
"""A human readable label for the object."""
r = getattr(self, 'name', None)
if not r:
r = getattr(self, 'id', None)
if not r:
r = '%s%s' % (self.__class__.__name__, self.pk)
return r
def __repr__(self):
return '<%s %r>' % (
self.__class__.__name__, getattr(self, 'id', self.pk))
class PolymorphicBaseMixin(object):
"""Mixin providing the wiring for joined table inheritance.
We use joined table inheritance to allow projects to augment base ``clld``
models with project specific attributes. This mixin class prepares
models to serve as base classes for inheritance.
"""
polymorphic_type = Column(String(20))
@declared_attr
def __mapper_args__(cls):
return {
'polymorphic_on': cls.polymorphic_type,
'polymorphic_identity': 'base',
'with_polymorphic': '*',
}
def is_base(cls):
"""Determine whether a class is a base class or an inheriting one.
:param cls: Model class.
:return: ``bool`` signaling whether ``cls`` is a base class or derived, i.e.\
customized.
"""
# replace with inspection?
# see http://docs.sqlalchemy.org/en/rel_0_9/orm/mapper_config.html
# ?highlight=polymorphic_identity#sqlalchemy.orm.mapper.Mapper.polymorphic_identity
return PolymorphicBaseMixin in cls.__bases__
[docs]class CustomModelMixin(object):
"""Mixin for customized classes in our joined table inheritance scheme.
.. note::
With this scheme there can be only one specialized mapper class per inheritable
base class.
"""
@declared_attr
def __mapper_args__(cls):
return {'polymorphic_identity': 'custom'} # pragma: no cover