Source code for neurobooth_terra.postgres

# Authors: Mainak Jas <mjas@mgh.harvard.edu>

import pandas as pd

import psycopg2
import psycopg2.extras as extras

#### Some useful PSQL commands
# pg_ctl -D /usr/local/var/postgres start  --> start server
# psql mydatabasename
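
# Example of opening a connection for the helpers below (a sketch; the
# database name and credentials are hypothetical):
#
#   conn = psycopg2.connect('dbname=neurobooth user=postgres')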

#### Thin wrappers around psycopg2 functions ####

def execute(conn, cursor, cmd, fetch=False):
    """Execute a command, commit, and optionally fetch the results."""
    cursor.execute(cmd)
    conn.commit()
    if fetch:
        return cursor.fetchall()
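
# Example (a sketch; assumes an open connection "conn" and an existing
# table named "subject"):
#
#   cursor = conn.cursor()
#   rows = execute(conn, cursor, 'SELECT * FROM subject;', fetch=True)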


def _execute_batch(conn, cursor, cmd, tuples, page_size=100):
    """Execute a parametrized command against each tuple and commit."""
    extras.execute_batch(cursor, cmd, tuples, page_size)
    conn.commit()


def _get_primary_keys(conn, cursor, table_id):
    """Look up the primary key column(s) of a table in the pg catalog."""
    query = (
        "SELECT a.attname "
        "FROM   pg_index i "
        "JOIN   pg_attribute a ON a.attrelid = i.indrelid "
        "AND    a.attnum = ANY(i.indkey) "
        f"WHERE  i.indrelid = '{table_id}'::regclass "
        "AND    i.indisprimary;"
    )
    column_names = execute(conn, cursor, query, fetch=True)
    primary_keys = [col[0] for col in column_names]
    return primary_keys


#### Neurobooth-related commands ####

def df_to_psql(conn, cursor, df, table_id):
    """Write a dataframe to a PostgreSQL table.

    Parameters
    ----------
    conn : instance of psycopg2.extensions.connection
        The connection object.
    cursor : instance of psycopg2.extensions.cursor
        The cursor object.
    df : instance of pd.DataFrame
        The dataframe to insert into the table.
    table_id : str
        The table_id to create.
    """
    # Replace NaN with None so that missing values become SQL NULL
    df = df.where(~df.isna(), None)
    tuples = [tuple(x) for x in df.to_numpy()]
    # Comma-separated dataframe columns and matching placeholders
    cols = ','.join(list(df.columns))
    vals = ','.join(len(df.columns) * ['%s'])

    # Create the table if needed, typing every column as VARCHAR(255)
    create_cmd = f'CREATE TABLE IF NOT EXISTS {table_id}('
    create_cmd += ', '.join(f'{col} VARCHAR(255)' for col in df.columns)
    create_cmd += ');'
    execute(conn, cursor, create_cmd)

    insert_cmd = f'INSERT INTO {table_id}({cols}) VALUES({vals})'
    _execute_batch(conn, cursor, insert_cmd, tuples)
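

# Example (a sketch; the dataframe and table name are hypothetical):
#
#   df = pd.DataFrame(dict(subject_id=['01', '02'], site=['MGH', 'MIT']))
#   cursor = conn.cursor()
#   df_to_psql(conn, cursor, df, 'demo_table')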


def query(conn, sql_query, column_names):
    """Transform a SELECT query into a pandas dataframe.

    Parameters
    ----------
    conn : instance of psycopg2.extensions.connection
        The connection object.
    sql_query : str
        The SQL query to perform.
    column_names : str | list of str
        The columns to create.

    Returns
    -------
    df : instance of pd.DataFrame
        The pandas dataframe.
    """
    if isinstance(column_names, str):
        column_names = [column_names]
    cursor = conn.cursor()
    data = execute(conn, cursor, sql_query, fetch=True)
    df = pd.DataFrame(data, columns=column_names)
    cursor.close()
    return df
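

# Example (a sketch; assumes a "subject" table with these columns):
#
#   df = query(conn, 'SELECT subject_id, first_name_birth FROM subject;',
#              column_names=['subject_id', 'first_name_birth'])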


def drop_table(table_id, conn):
    """Drop a table.

    Parameters
    ----------
    table_id : str
        The table ID.
    conn : instance of psycopg2.extensions.connection
        The connection object.
    """
    cursor = conn.cursor()
    cmd = f'DROP TABLE IF EXISTS "{table_id}" CASCADE;'
    execute(conn, cursor, cmd)
    cursor.close()


def create_table(table_id, conn, column_names, dtypes, primary_key=None,
                 foreign_key=None, index=None):
    """Create a table.

    Parameters
    ----------
    table_id : str
        The table ID.
    conn : instance of psycopg2.extensions.connection
        The connection object.
    column_names : list of str
        The columns to create.
    dtypes : list of str
        The datatypes.
    primary_key : str | list | None
        The primary key. If None, the first column name is used as
        primary key. If list, the primary key is a combination of the
        columns in the list.
    foreign_key : dict | None
        Foreign key referring to another table. The key is the name of
        the foreign key and the value is the table it refers to.
    index : dict of list | None
        The key is the name of the index and the values are the column
        names on which to create the unique index.

    Returns
    -------
    table : instance of Table
        The table object.
    """
    # XXX: add check for columns if table already exists
    if len(column_names) != len(dtypes):
        raise ValueError('Column names and data types should have '
                         'equal lengths')
    if primary_key is None:
        primary_key = column_names[0]
    if isinstance(primary_key, str):
        primary_key = [primary_key]

    create_cmd = f'CREATE TABLE "{table_id}" ('
    for column_name, dtype in zip(column_names, dtypes):
        create_cmd += f'"{column_name}" {dtype},'
    create_cmd += f'PRIMARY KEY({", ".join(primary_key)}),'

    if foreign_key is None:
        foreign_key = dict()
    for key in foreign_key:
        create_cmd += (f'FOREIGN KEY ({key}) '
                       f'REFERENCES {foreign_key[key]}({key}),')
    create_cmd = create_cmd[:-1] + ');'  # replace trailing comma

    cursor = conn.cursor()
    try:
        execute(conn, cursor, create_cmd)
    except Exception:
        cursor.close()
        raise

    if index is not None:
        index_name = list(index.keys())[0]
        index_cols = ', '.join(list(index.values())[0])
        drop_cmd = f'DROP INDEX IF EXISTS {index_name}'
        execute(conn, cursor, drop_cmd)
        index_cmd = (f'CREATE UNIQUE INDEX {index_name} ON '
                     f'{table_id} ({index_cols});')
        execute(conn, cursor, index_cmd)

    return Table(table_id, conn=conn, cursor=cursor,
                 primary_key=primary_key)
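

# Example (a sketch; the table, columns and keys are hypothetical and
# assume an existing "subject" table):
#
#   table = create_table(
#       'consent', conn,
#       column_names=['subject_id', 'consent_date'],
#       dtypes=['VARCHAR (255)', 'date'],
#       primary_key='subject_id',
#       foreign_key=dict(subject_id='subject'))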


class Table:
    """Table class that is a wrapper around a Postgres SQL table.

    Parameters
    ----------
    table_id : str
        The table ID.
    conn : instance of psycopg2.extensions.connection
        The connection object.
    cursor : instance of psycopg2.extensions.cursor | None
        The cursor object. If None, a new cursor is created.
    primary_key : str | None
        The primary key. If None, it is inferred from the database.

    Attributes
    ----------
    column_names : list of str
        The column names.
    data_types : list of str
        The data types of the column names.
    primary_key : list of str
        The primary key. May be more than one column in case of a
        compound primary key.
    """

    def __init__(self, table_id, conn, cursor=None, primary_key=None):
        self.conn = conn
        if cursor is None:
            cursor = conn.cursor()
        self.cursor = cursor
        self.table_id = table_id

        # Read column names, types and maximum lengths from the schema
        cmd = ("SELECT column_name, data_type, character_maximum_length"
               " FROM INFORMATION_SCHEMA.COLUMNS WHERE "
               f"table_name = '{table_id}';")
        columns = execute(conn, cursor, cmd, fetch=True)

        self.column_names = list()
        self.data_types = list()
        for column_name, dtype, maxlen in columns:
            if dtype == 'character varying':
                dtype = f'VARCHAR ({maxlen})'
            self.column_names.append(column_name)
            self.data_types.append(dtype.upper())

        if primary_key is None:
            primary_key = _get_primary_keys(conn, cursor, table_id)
        if isinstance(primary_key, str):
            primary_key = [primary_key]
        self.primary_key = primary_key

    def __repr__(self):
        repr_str = f'Table "{self.table_id}" '
        repr_str += '(' + ', '.join(self.column_names) + ')'
        return repr_str

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the cursor."""
        self.cursor.close()

    def alter_column(self, col, default=None):
        """Alter a column in the table.

        Parameters
        ----------
        col : str
            The column name.
        default : str | dict | None
            The default value of the column. To specify a prefix that
            autoincrements, pass dict(prefix=prefix), e.g.,
            dict(prefix='SUBJECT'). If None, nothing is changed.
        """
        cmd = f"ALTER TABLE {self.table_id} ALTER COLUMN {col} "
        if isinstance(default, str):
            cmd += f'SET DEFAULT {default}'
            execute(self.conn, self.cursor, cmd)
        elif isinstance(default, dict):
            # Recreate a sequence that autoincrements the numeric suffix
            sequence_name = f'{self.table_id}_{col}'
            seq_cmd1 = f'DROP SEQUENCE IF EXISTS {sequence_name}'
            seq_cmd2 = f'CREATE SEQUENCE IF NOT EXISTS {sequence_name}'
            execute(self.conn, self.cursor, seq_cmd1)
            execute(self.conn, self.cursor, seq_cmd2)

            prefix = default['prefix']
            cmd += f"SET DEFAULT '{prefix}' || nextval('{sequence_name}')"
            execute(self.conn, self.cursor, cmd)

            # Enforce the prefix + number pattern on inserted values
            constraint_name = f'{sequence_name}_chk'
            check_cmd = (f"ALTER TABLE {self.table_id} "
                         f"ADD CONSTRAINT {constraint_name} "
                         f"CHECK ({col} ~ '^{prefix}[0-9]+$')")
            execute(self.conn, self.cursor, check_cmd)
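
    # Example (a sketch; assumes a Table instance "table" with a
    # "subject_id" column):
    #
    #   table.alter_column('subject_id', default=dict(prefix='SUBJECT'))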

    def add_column(self, col, dtype):
        """Add a new column to the table.

        Parameters
        ----------
        col : str
            The column name.
        dtype : str
            The data type of the column.
        """
        cmd = f'ALTER TABLE {self.table_id} '
        cmd += f'ADD COLUMN {col} {dtype};'
        execute(self.conn, self.cursor, cmd)
        self.column_names.append(col)
        self.data_types.append(dtype.upper())  # keep in sync with columns

    def drop_column(self, col):
        """Drop a column from the table.

        Parameters
        ----------
        col : str
            The column name.
        """
        cmd = f'ALTER TABLE {self.table_id} '
        cmd += f'DROP COLUMN {col};'
        execute(self.conn, self.cursor, cmd)
        idx = self.column_names.index(col)
        del self.column_names[idx], self.data_types[idx]

    def insert_rows(self, vals, cols, on_conflict='error',
                    conflict_cols='auto', update_cols='all', where=None):
        """Manual insertion into tables.

        Parameters
        ----------
        vals : list of tuple
            The records to insert. Each tuple is one row.
        cols : list of str
            The columns to insert into.
        on_conflict : 'nothing' | 'update' | 'error'
            What to do when a conflict is encountered.
        conflict_cols : 'auto' | str | list
            If 'auto', the primary key is used when on_conflict is
            'update'. If list, the listed columns (which must carry a
            unique index) are used to infer conflicts.
        update_cols : 'all' | str | list
            If 'all', updates all the columns with the new values.
            If list, updates only those columns.
        where : str | None
            Condition to filter rows by. If None, keep all rows where
            the primary key is not NULL.

        Returns
        -------
        pk_val : str | None
            The primary key of the row inserted into. If multiple rows
            are inserted, returns None.

        Notes
        -----
        When conflict_cols is a list, a unique index must be set on the
        target columns. The following SQL command is handy:

            create unique index subject_identifier on subject
            (first_name_birth, last_name_birth, date_of_birth);
        """
        if not isinstance(vals, list):
            raise ValueError(f'vals must be a list of tuple. '
                             f'Got {type(vals)}')
        vals_clean = list()
        for val in vals:
            if not isinstance(val, tuple):
                raise ValueError(f'entries in vals must be tuples. '
                                 f'Got {type(val)}')
            if len(val) != len(cols):
                raise ValueError(f'tuple length must match number of '
                                 f'columns ({len(cols)})')
            # Adapt dicts to JSON so psycopg2 can insert them
            val = tuple([extras.Json(this_val)
                         if isinstance(this_val, dict) else this_val
                         for this_val in val])
            vals_clean.append(val)
        vals = vals_clean

        if on_conflict not in ('nothing', 'update', 'error'):
            raise ValueError('on_conflict must be one of '
                             f'(nothing, update, error). Got {on_conflict}')

        if conflict_cols == 'auto':
            conflict_cols = self.primary_key
        if isinstance(conflict_cols, str):
            conflict_cols = [conflict_cols]
        conflict_cols = ', '.join(conflict_cols)

        if update_cols == 'all':
            update_cols = cols.copy()
        if isinstance(update_cols, str):
            update_cols = [update_cols]

        if where is None:
            where = f'{self.table_id}.{self.primary_key[0]} is NOT NULL'

        str_format = ','.join(len(cols) * ['%s'])
        cols = ','.join([f'"{col}"' for col in cols])

        insert_cmd = (f'INSERT INTO {self.table_id}({cols}) '
                      f'VALUES({str_format}) ')
        if on_conflict == 'nothing':
            insert_cmd += 'ON CONFLICT DO NOTHING '
        elif on_conflict == 'update':
            insert_cmd += f'ON CONFLICT ({conflict_cols}) DO UPDATE SET '
            update_cmd = list()
            for col_name in update_cols:
                update_cmd.append(f'"{col_name}" = excluded."{col_name}"')
            insert_cmd += ', '.join(update_cmd) + ' '
            insert_cmd += f'WHERE {where} '
        insert_cmd += f'RETURNING {self.primary_key[0]}'

        _execute_batch(self.conn, self.cursor, insert_cmd, vals)
        if len(vals) == 1 and on_conflict != 'nothing':
            return self.cursor.fetchone()[0]
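
    # Example (a sketch; table columns and values are hypothetical):
    #
    #   table.insert_rows([('01', 'Anna'), ('02', 'Ben')],
    #                     cols=['subject_id', 'first_name_birth'],
    #                     on_conflict='update')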

    def update_row(self, pk_val, vals, cols=None):
        """Update values in a row.

        Parameters
        ----------
        pk_val : str
            The value of the primary key to match the row to replace.
        vals : tuple
            The values in the row to replace.
        cols : list of str | None
            The columns to insert into. If None, all columns are used.
        """
        if not isinstance(vals, tuple):
            raise ValueError('vals must be a tuple')
        if cols is None:
            cols = self.column_names
        if len(cols) != len(vals):
            raise ValueError(f'length of vals ({len(vals)}) != '
                             f'length of cols ({len(cols)})')

        cmd = f'UPDATE {self.table_id} SET '
        for col, val in zip(cols, vals):
            if col not in self.column_names:
                raise ValueError(f'column {col} is not present in table')
            cmd += f"\"{col}\" = '{val}', "
        cmd = cmd[:-2]  # remove trailing comma and space
        pk = self.primary_key[0]
        cmd += f" WHERE {pk} = '{pk_val}';"
        execute(self.conn, self.cursor, cmd)
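
    # Example (a sketch; the column name is hypothetical):
    #
    #   table.update_row('01', ('Anne',), cols=['first_name_birth'])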

    def query(self, include_columns=None, where=None):
        """Run a query.

        Parameters
        ----------
        include_columns : str | list of str | None
            The columns to include. If None, query all columns.
        where : str | None
            Condition to filter rows by. If None, keep all rows.
            E.g., table.query(where='"wearable_bool" = True')

        Returns
        -------
        df : instance of pd.DataFrame
            A pandas dataframe object.
        """
        if include_columns is None:
            include_columns = self.column_names
        if isinstance(include_columns, str):
            include_columns = [include_columns]
        # use quotes to be case sensitive
        cols = ', '.join([f'"{col}"' for col in include_columns])

        cmd = f"SELECT {cols} FROM {self.table_id} "
        if where is not None:
            cmd += f"WHERE {where}"
        cmd += ';'

        data = execute(self.conn, self.cursor, cmd, fetch=True)
        df = pd.DataFrame(data, columns=include_columns)
        pk = self.primary_key[0]
        if pk in df.columns:
            df = df.set_index(pk)
        return df
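
    # Example (a sketch; the column name is hypothetical):
    #
    #   df = table.query(where="\"subject_id\" = '01'")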

    def delete_row(self, condition=None):
        """Delete rows from the table.

        Parameters
        ----------
        condition : str | None
            The condition to filter rows by and delete them.
            If None, all rows are deleted.
        """
        delete_cmd = f'DELETE FROM {self.table_id} '
        if condition is not None:
            delete_cmd += f'WHERE {condition}'
        delete_cmd += ';'
        execute(self.conn, self.cursor, delete_cmd)
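
    # Example (a sketch; the condition is hypothetical):
    #
    #   table.delete_row(condition="\"subject_id\" = '01'")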

    def drop(self):
        """Drop the table."""
        drop_table(self.table_id, self.conn)


def list_tables(conn):
    """List the table_ids in the database.

    Parameters
    ----------
    conn : instance of psycopg2.extensions.connection
        The connection object.

    Returns
    -------
    table_ids : list of str
        The table IDs.
    """
    query_tables_cmd = """
    SELECT *
    FROM pg_catalog.pg_tables
    WHERE schemaname != 'pg_catalog'
    AND schemaname != 'information_schema';
    """
    cursor = conn.cursor()
    cursor.execute(query_tables_cmd)
    tables = cursor.fetchall()
    cursor.close()
    # The second column of pg_tables is the table name
    table_ids = [table[1] for table in tables]
    return table_ids
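

# Example (a sketch):
#
#   table_ids = list_tables(conn)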