Source code for clld.db.migration

"""Functionality for alembic scripts.

This module provides

- basic crud functionality within alembic migration scripts,
- advanced helpers for special tasks, like merging sources.

.. note::

    Using the functionality provided in this module is not possible for Alembic scripts
    supposed to be run in
    `offline mode <\

from sqlalchemy.sql import select as base_select, func
from clld.db.models import common

def _with_where_clause(table, stmt, **where):
    for key, value in where.items():
        stmt = stmt.where(getattr(table.c, key) == value)
    return stmt

[docs]class Connection(object): """A wrapper around an SQLAlchemy connection. This wrapper provides the convenience of allowing typical CRUD operations to be called passing model classes. Additionally, it implements more complicated clld domain specific database operations. A ``Connection`` will typically be instantiated in an Alembic migration script as follows: .. code-block:: python from alembic import op conn = Connection(op.get_bind()) """ def __init__(self, conn): """Initialize. :param conn: SQLAlchemy connection object, i.e. anything with a ``execute`` method; so in particular a session qualifies as well as an engine, i.e. what ``get_bind()`` returns. """ self._conn = conn
[docs] def execute(self, *args, **kw): """Provide access to the underlying connection's ``execute`` method.""" return self._conn.execute(*args, **kw)
# # CRUD operations: #
[docs] def select(self, model, **where): """Run a select statement and return a ResultProxy.""" table = model.__table__ return self.execute(_with_where_clause(table, base_select([table]), **where))
[docs] def all(self, model, **where): """return all results of a select statement.""" return, **where).fetchall()
[docs] def first(self, model, **where): """return first result of a select statement or ``None``.""" return, **where).fetchone()
[docs] def get(self, model, pk): """return row specified by primary key.""" return self.first(model, pk=pk)
[docs] def pk(self, model, id_, attr='id'): """Get the primary key of an object specified by a unique property. :param model: model class. :param id_: Value to be used when filtering. :param attr: Column to be used for filtering. :return: primary key of (first) matching row. """ res = self.first(model, **{attr: id_}) if res: return
[docs] def insert(self, model, **values): """Run an insert statement. :return: primary key of the inserted row. """ for k, v in [ ('version', 1), ('created',, ('updated',, ('active', True) ]: if hasattr(model.__table__.c, k): values.setdefault(k, v) res = self.execute(model.__table__.insert().values(**values)) return res.inserted_primary_key[0]
[docs] def update(self, model, values, **where): """Run an update statement.""" if not isinstance(values, dict): values = dict(values) if hasattr(model.__table__.c, 'updated'): values.setdefault('updated', table = model.__table__ self.execute(_with_where_clause(table, table.update(), **where).values(**values))
[docs] def delete(self, model, **where): """Run a delete statement.""" self.execute( _with_where_clause(model.__table__, model.__table__.delete(), **where))
# # domain specific operations: #
[docs] def set_glottocode(self, lid, gc, gcid=None): """assign a unique glottocode to a language. i.e. alternative glottocodes will be deleted. :param lid: ``id`` of the language. :param gc: Glottocode to be assigned. :param gcid: ``id`` of the ``Identifier`` instance if one has to be created; defaults to ``gc``. """ lpk =, lid) gctype = common.IdentifierType.glottolog.value done = False lis = self.all(common.LanguageIdentifier, language_pk=lpk) for li in lis: i = self.get(common.Identifier, li.identifier_pk) if i.type == gctype: if == gc: done = True else: self.delete(common.LanguageIdentifier, if not done: # create a new relation i = self.first(common.Identifier, name=gc, type=gctype) if i: ipk = else: ipk = self.insert(common.Identifier, id=gcid or gc, name=gc, type=gctype) return self.insert( common.LanguageIdentifier, language_pk=lpk, identifier_pk=ipk)
def replace(self, model, from_id, to_id): # pragma: no cover self.insert( common.Config, key=common.Config.replacement_key(model, from_id), value=to_id)
def merge_sources(conn, from_id, to_id, *fields): # pragma: no cover if not isinstance(conn, Connection): conn = Connection(conn) # resolve id to pk fpk =, from_id) tpk =, to_id) if fields: conn.execute("""\ UPDATE source SET %s FROM (SELECT %s FROM source WHERE pk = %s) AS old WHERE pk = %s""" % (', '.join('{0}=old.{0}'.format(f) for f in fields), ', '.join(fields), fpk, tpk)) # update relationships for row in conn.all(common.LanguageSource, source_pk=fpk): lpk = row['language_pk'] if conn.first(common.LanguageSource, language_pk=lpk, source_pk=tpk): conn.delete(common.LanguageSource, source_pk=fpk) else: conn.update(common.LanguageSource, dict(source_pk=tpk), source_pk=fpk) for model in [ common.SentenceReference, common.ContributionReference, common.ValueSetReference ]: conn.update(model, dict(source_pk=tpk), source_pk=fpk) # add replacement record conn.replace(common.Source, from_id, to_id) # remove source conn.delete(common.Source, id=from_id)