# Mirror of https://github.com/welton89/RRBEC.git
# (synced 2026-04-06 05:55:42 +00:00; 1404 lines, 62 KiB, Python)
# -*- coding: utf-8 -*-
|
|
"""
|
|
sqldiff.py - Prints the (approximated) difference between models and database
|
|
|
|
TODO:
|
|
- better support for relations
|
|
- better support for constraints (mainly postgresql?)
|
|
- support for table spaces with postgresql
|
|
- when a table is not managed (meta.managed==False) then only do a one-way
|
|
sqldiff — show differences from db->table but not the other way around since
|
|
it's not managed.
|
|
|
|
KNOWN ISSUES:
|
|
- MySQL has by far the most problems with introspection. Please be
|
|
careful when using MySQL with sqldiff.
|
|
- Booleans are reported back as Integers, so there's no way to know if
|
|
there was a real change.
|
|
- Varchar sizes are reported back without unicode support so their size
|
|
may change in comparison to the real length of the varchar.
|
|
- Some of the 'fixes' to counter these problems might create false
|
|
positives or false negatives.
|
|
"""
|
|
|
|
import importlib
|
|
import sys
|
|
import argparse
|
|
from typing import Dict, Union, Callable, Optional # NOQA
|
|
from django.apps import apps
|
|
from django.core.management import BaseCommand, CommandError
|
|
from django.core.management.base import OutputWrapper
|
|
from django.core.management.color import no_style
|
|
from django.db import connection, transaction, models
|
|
from django.db.models import UniqueConstraint
|
|
from django.db.models.fields import AutoField, IntegerField
|
|
from django.db.models.options import normalize_together
|
|
|
|
from django_extensions.management.utils import signalcommand
|
|
|
|
# Stand-in model field for the implicit '_order' column; find_differences()
# puts it into the field map of models that use Meta.order_with_respect_to.
# Django's first positional Field argument is verbose_name, spelled out here.
ORDERING_FIELD = IntegerField(verbose_name='_order', null=True)
|
|
|
|
|
|
def flatten(lst, ltypes=(list, tuple)):
    """Flatten arbitrarily nested sequences into a single level.

    Every element that is an instance of ``ltypes`` is expanded in place
    (recursively); empty nested containers simply disappear. The result is
    built with the same type as ``lst`` (e.g. a tuple in gives a tuple out).
    Elements that are not instances of ``ltypes`` (strings, numbers, ...)
    are kept as-is.
    """
    result_type = type(lst)
    flat = []
    # Explicit stack of partially consumed iterators: no recursion limit, and
    # each nested container resumes where it left off once a child finishes.
    pending = [iter(lst)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, ltypes):
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            pending.pop()
    return result_type(flat)
|
|
|
|
|
|
def all_local_fields(meta):
    """Return the column-backed local fields described by a model ``Options``.

    For proxy models the fields are gathered recursively from every model in
    ``meta.parents``. Otherwise ``meta.local_fields`` is filtered to the
    fields whose ``db_type()`` on the current connection is not None, i.e.
    fields that actually map to a database column.
    """
    if meta.proxy:
        collected = []
        for parent_model in meta.parents:
            collected += all_local_fields(parent_model._meta)
        return collected
    return [
        local_field
        for local_field in meta.local_fields
        if local_field.db_type(connection=connection) is not None
    ]
|
|
|
|
|
|
class SQLDiff:
    """Compute and report differences between Django models and the database.

    Construct with the models to check and the command options, call
    ``load()`` to introspect the database (it sets ``cursor``/``db_tables``
    used later), then ``find_differences()`` and finally ``print_diff()``.

    Differences are stored in ``self.differences`` as a list of
    ``(app_label, model_name, [(diff_type, args), ...])`` entries. Backend
    specific subclasses override the ``SQL_*`` renderers and the
    introspection hooks (``load_null``, ``load_unsigned``,
    ``get_field_db_type_lookup``, ``get_field_db_type_kwargs``, ...).
    """

    # Reverse type overrides keyed by the DB-API ``type_code`` of
    # ``cursor.description``; consulted before ``introspection.get_field_type``.
    # A value may be a field class name or dotted path, a ``(name, kwargs)``
    # tuple, a ``{'name': ..., 'kwargs': ...}`` dict, or a callable returning
    # one of those (see get_field_db_type).
    DATA_TYPES_REVERSE_OVERRIDE = {}  # type: Dict[int, Union[str, Callable]]

    # Database tables never reported as "missing in models".
    IGNORE_MISSING_TABLES = [
        "django_migrations",
    ]

    # Every difference type accepted by add_difference().
    DIFF_TYPES = [
        'error',
        'comment',
        'table-missing-in-db',
        'table-missing-in-model',
        'field-missing-in-db',
        'field-missing-in-model',
        'fkey-missing-in-db',
        'fkey-missing-in-model',
        'index-missing-in-db',
        'index-missing-in-model',
        'unique-missing-in-db',
        'unique-missing-in-model',
        'field-type-differ',
        'field-parameter-differ',
        'notnull-differ',
    ]

    # Text templates; ``%(N)s`` is the N-th argument given to add_difference().
    DIFF_TEXTS = {
        'error': 'error: %(0)s',
        'comment': 'comment: %(0)s',
        'table-missing-in-db': "table '%(0)s' missing in database",
        'table-missing-in-model': "table '%(0)s' missing in models",
        'field-missing-in-db': "field '%(1)s' defined in model but missing in database",
        'field-missing-in-model': "field '%(1)s' defined in database but missing in model",
        'fkey-missing-in-db': "field '%(1)s' FOREIGN KEY defined in model but missing in database",
        'fkey-missing-in-model': "field '%(1)s' FOREIGN KEY defined in database but missing in model",
        'index-missing-in-db': "field '%(1)s' INDEX named '%(2)s' defined in model but missing in database",
        'index-missing-in-model': "field '%(1)s' INDEX defined in database schema but missing in model",
        'unique-missing-in-db': "field '%(1)s' UNIQUE named '%(2)s' defined in model but missing in database",
        'unique-missing-in-model': "field '%(1)s' UNIQUE defined in database schema but missing in model",
        'field-type-differ': "field '%(1)s' not of same type: db='%(3)s', model='%(2)s'",
        'field-parameter-differ': "field '%(1)s' parameters differ: db='%(3)s', model='%(2)s'",
        'notnull-differ': "field '%(1)s' null constraint should be '%(2)s' in the database",
    }

    # SQL renderers, one per difference type (wired up in DIFF_SQL). Each takes
    # (style, qn, args): ``qn`` quotes identifiers and ``args`` are the
    # arguments that were passed to add_difference().
    SQL_FIELD_MISSING_IN_DB = lambda self, style, qn, args: "%s %s\n\t%s %s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('ADD COLUMN'),
        style.SQL_FIELD(qn(args[1])),
        ' '.join(style.SQL_COLTYPE(a) if i == 0 else style.SQL_KEYWORD(a) for i, a in enumerate(args[2:]))
    )
    SQL_FIELD_MISSING_IN_MODEL = lambda self, style, qn, args: "%s %s\n\t%s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('DROP COLUMN'),
        style.SQL_FIELD(qn(args[1]))
    )
    SQL_FKEY_MISSING_IN_DB = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s %s (%s)%s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('ADD COLUMN'),
        style.SQL_FIELD(qn(args[1])),
        ' '.join(style.SQL_COLTYPE(a) if i == 0 else style.SQL_KEYWORD(a) for i, a in enumerate(args[4:])),
        style.SQL_KEYWORD('REFERENCES'),
        style.SQL_TABLE(qn(args[2])),
        style.SQL_FIELD(qn(args[3])),
        connection.ops.deferrable_sql()
    )
    SQL_INDEX_MISSING_IN_DB = lambda self, style, qn, args: "%s %s\n\t%s %s (%s%s);" % (
        style.SQL_KEYWORD('CREATE INDEX'),
        style.SQL_TABLE(qn(args[2])),
        style.SQL_KEYWORD('ON'), style.SQL_TABLE(qn(args[0])),
        style.SQL_FIELD(', '.join(qn(e) for e in args[1])),
        style.SQL_KEYWORD(args[3])
    )
    SQL_INDEX_MISSING_IN_MODEL = lambda self, style, qn, args: "%s %s;" % (
        style.SQL_KEYWORD('DROP INDEX'),
        style.SQL_TABLE(qn(args[1]))
    )
    SQL_UNIQUE_MISSING_IN_DB = lambda self, style, qn, args: "%s %s\n\t%s %s %s (%s);" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('ADD CONSTRAINT'),
        style.SQL_TABLE(qn(args[2])),
        style.SQL_KEYWORD('UNIQUE'),
        style.SQL_FIELD(', '.join(qn(e) for e in args[1]))
    )
    SQL_UNIQUE_MISSING_IN_MODEL = lambda self, style, qn, args: "%s %s\n\t%s %s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('DROP'),
        style.SQL_KEYWORD('CONSTRAINT'),
        style.SQL_TABLE(qn(args[1]))
    )
    SQL_FIELD_TYPE_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD("MODIFY"),
        style.SQL_FIELD(qn(args[1])),
        style.SQL_COLTYPE(args[2])
    )
    SQL_FIELD_PARAMETER_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD("MODIFY"),
        style.SQL_FIELD(qn(args[1])),
        style.SQL_COLTYPE(args[2])
    )
    SQL_NOTNULL_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('MODIFY'),
        style.SQL_FIELD(qn(args[1])),
        style.SQL_KEYWORD(args[2]),
        style.SQL_KEYWORD('NOT NULL')
    )
    SQL_ERROR = lambda self, style, qn, args: style.NOTICE('-- Error: %s' % style.ERROR(args[0]))
    SQL_COMMENT = lambda self, style, qn, args: style.NOTICE('-- Comment: %s' % style.SQL_TABLE(args[0]))
    SQL_TABLE_MISSING_IN_DB = lambda self, style, qn, args: style.NOTICE('-- Table missing: %s' % args[0])
    SQL_TABLE_MISSING_IN_MODEL = lambda self, style, qn, args: style.NOTICE('-- Model missing for table: %s' % args[0])

    # Capability flags; subclasses that set them to True must implement
    # load_null() / load_unsigned() respectively.
    can_detect_notnull_differ = False
    can_detect_unsigned_differ = False
    # Suffix re-attached to introspected types of unsigned columns (e.g. 'UNSIGNED').
    unsigned_suffix = None  # type: Optional[str]

    def __init__(self, app_models, options, stdout, stderr):
        """Store the models to check and the command options.

        ``options`` must provide the keys read by this class: 'dense_output',
        'only_existing', 'all_applications', 'include_proxy_models',
        'include_defaults' and 'sql'.
        """
        # Set by find_differences(): the largest number of differences in any entry.
        self.has_differences = None
        self.app_models = app_models
        self.options = options
        self.dense = options['dense_output']
        self.stdout = stdout
        self.stderr = stderr

        self.introspection = connection.introspection

        # [(app_label, model_name, [(diff_type, args), ...]), ...]
        self.differences = []
        # Keys of columns already reported with an unknown database type.
        self.unknown_db_fields = {}
        # (table_name, column) pairs reported as missing in the database.
        self.new_db_fields = set()
        # (tablespace, table_name, column) -> nullable flag, filled by load_null().
        self.null = {}
        # (tablespace, table_name, column) of unsigned columns, filled by load_unsigned().
        self.unsigned = set()

        self.DIFF_SQL = {
            'error': self.SQL_ERROR,
            'comment': self.SQL_COMMENT,
            'table-missing-in-db': self.SQL_TABLE_MISSING_IN_DB,
            'table-missing-in-model': self.SQL_TABLE_MISSING_IN_MODEL,
            'field-missing-in-db': self.SQL_FIELD_MISSING_IN_DB,
            'field-missing-in-model': self.SQL_FIELD_MISSING_IN_MODEL,
            'fkey-missing-in-db': self.SQL_FKEY_MISSING_IN_DB,
            'fkey-missing-in-model': self.SQL_FIELD_MISSING_IN_MODEL,
            'index-missing-in-db': self.SQL_INDEX_MISSING_IN_DB,
            'index-missing-in-model': self.SQL_INDEX_MISSING_IN_MODEL,
            'unique-missing-in-db': self.SQL_UNIQUE_MISSING_IN_DB,
            'unique-missing-in-model': self.SQL_UNIQUE_MISSING_IN_MODEL,
            'field-type-differ': self.SQL_FIELD_TYPE_DIFFER,
            'field-parameter-differ': self.SQL_FIELD_PARAMETER_DIFFER,
            'notnull-differ': self.SQL_NOTNULL_DIFFER,
        }

    def load(self):
        """Introspect the database: table lists plus optional null/unsigned info."""
        self.cursor = connection.cursor()
        self.django_tables = self.introspection.django_table_names(only_existing=self.options['only_existing'])
        # TODO: We are losing information about tables which are views here
        self.db_tables = [table_info.name for table_info in self.introspection.get_table_list(self.cursor)]

        if self.can_detect_notnull_differ:
            self.load_null()

        if self.can_detect_unsigned_differ:
            self.load_unsigned()

    def load_null(self):
        """Fill ``self.null``; required when can_detect_notnull_differ is True."""
        raise NotImplementedError("load_null functions must be implemented if diff backend has 'can_detect_notnull_differ' set to True")

    def load_unsigned(self):
        """Fill ``self.unsigned``; required when can_detect_unsigned_differ is True."""
        raise NotImplementedError("load_unsigned function must be implemented if diff backend has 'can_detect_unsigned_differ' set to True")

    def add_app_model_marker(self, app_label, model_name):
        """Start a new entry; later add_difference() calls are attributed to it."""
        self.differences.append((app_label, model_name, []))

    def add_difference(self, diff_type, *args):
        """Record a difference of ``diff_type`` on the most recent marker.

        Requires add_app_model_marker() to have been called first.
        """
        assert diff_type in self.DIFF_TYPES, 'Unknown difference type'
        self.differences[-1][-1].append((diff_type, args))

    def get_data_types_reverse_override(self):
        # type: () -> Dict[int, Union[str, Callable]]
        """Return the backend's reverse type overrides (hook for subclasses)."""
        return self.DATA_TYPES_REVERSE_OVERRIDE

    def format_field_names(self, field_names):
        """Normalize result column names used by sql_to_dict (identity here)."""
        return field_names

    def sql_to_dict(self, query, param):
        """
        Execute query and return a dict

        sql_to_dict(query, param) -> list of dicts

        code from snippet at https://www.djangosnippets.org/snippets/1383/

        Each row becomes a dict keyed by the (format_field_names-normalized)
        result column names. The cursor is closed when done.
        """
        with connection.cursor() as cursor:
            cursor.execute(query, param)
            fieldnames = self.format_field_names([name[0] for name in cursor.description])
            return [dict(zip(fieldnames, row)) for row in cursor.fetchall()]

    def get_field_model_type(self, field):
        """Return the column type the model field would produce."""
        return field.db_type(connection=connection)

    def get_field_db_type_kwargs(self, current_kwargs, description, field=None, table_name=None, reverse_type=None):
        """Hook for subclasses: extra kwargs for the reverse field class."""
        return {}

    def get_field_db_type(self, description, field=None, table_name=None):
        """Return the column type Django would generate for an introspected column.

        The DB-API ``description`` row is mapped back to a Django field class,
        which is instantiated with the parameters recovered from the
        description; its ``db_type()`` is returned so it can be compared with
        the model field's own ``db_type()``.

        Returns None, after recording a 'comment' difference once per column,
        when the type code cannot be mapped. ``field`` is optional; without it
        field-specific adjustments (geometry, unsigned suffix) are skipped.
        """
        # DB-API cursor.description
        # (name, type_code, display_size, internal_size, precision, scale, null_ok) = description
        type_code = description[1]
        overrides = self.get_data_types_reverse_override()
        if type_code in overrides:
            reverse_type = overrides[type_code]
        else:
            try:
                reverse_type = self.introspection.get_field_type(type_code, description)
            except KeyError:
                reverse_type = self.get_field_db_type_lookup(type_code)
                if not reverse_type:
                    # type_code not found in data_types_reverse map; report once per column
                    key = (self.differences[-1][:2], description[:2])
                    if key not in self.unknown_db_fields:
                        self.unknown_db_fields[key] = 1
                        self.add_difference('comment', "Unknown database type for field '%s' (%s)" % (description[0], type_code))
                    return None

        if callable(reverse_type):
            reverse_type = reverse_type()

        kwargs = {}

        if isinstance(reverse_type, dict):
            kwargs.update(reverse_type['kwargs'])
            reverse_type = reverse_type['name']

        # NOTE(review): 16946 appears to be a backend-specific geometry type code
        # (PostGIS?) - confirm before relying on it.
        if type_code == 16946 and field and getattr(field, 'geom_type', None) == 'POINT':
            reverse_type = 'django.contrib.gis.db.models.fields.PointField'

        if isinstance(reverse_type, tuple):
            kwargs.update(reverse_type[1])
            reverse_type = reverse_type[0]

        if reverse_type == "CharField" and description[3]:
            kwargs['max_length'] = description[3]

        if reverse_type == "DecimalField":
            kwargs['max_digits'] = description[4]
            scale = description[5]
            kwargs['decimal_places'] = abs(scale) if scale else scale

        if description[6]:
            kwargs['blank'] = True
            if reverse_type not in ('TextField', 'CharField'):
                kwargs['null'] = True

        if field and getattr(field, 'geography', False):
            kwargs['geography'] = True

        if reverse_type == 'GeometryField':
            geo_col = description[0]
            # Getting a more specific field type and any additional parameters
            # from the `get_geometry_type` routine for the spatial backend.
            reverse_type, geo_params = self.introspection.get_geometry_type(table_name, geo_col)
            if geo_params:
                kwargs.update(geo_params)
            reverse_type = 'django.contrib.gis.db.models.fields.%s' % reverse_type

        extra_kwargs = self.get_field_db_type_kwargs(kwargs, description, field, table_name, reverse_type)
        kwargs.update(extra_kwargs)

        field_class = self.get_field_class(reverse_type)
        field_db_type = field_class(**kwargs).db_type(connection=connection)

        # Re-attach the unsigned marker reported by load_unsigned(). This needs
        # the model field (for its column/tablespace) and a configured suffix;
        # previously a missing field raised AttributeError here.
        if field is not None and self.unsigned_suffix and field_db_type:
            tablespace = field.db_tablespace or "public"
            if (tablespace, table_name, field.column) in self.unsigned and self.unsigned_suffix not in field_db_type:
                field_db_type = '%s %s' % (field_db_type, self.unsigned_suffix)

        return field_db_type

    def get_field_db_type_lookup(self, type_code):
        """Hook for subclasses: fallback reverse type for unknown type codes."""
        return None

    def get_field_class(self, class_path):
        """Resolve a field class from a dotted path or a ``django.db.models`` name."""
        if '.' in class_path:
            module_path, package_name = class_path.rsplit('.', 1)
            module = importlib.import_module(module_path)
            return getattr(module, package_name)

        return getattr(models, class_path)

    def get_field_db_nullable(self, field, table_name):
        """Return the column's nullable flag from ``self.null``, or 'fixme' if unknown."""
        # Treat any empty tablespace ('' or None) as "public", matching get_field_db_type.
        tablespace = field.db_tablespace or "public"
        attname = field.db_column or field.attname
        return self.null.get((tablespace, table_name, attname), 'fixme')

    def strip_parameters(self, field_type):
        """Return the bare lower-cased type name, e.g. 'varchar(10) NOT NULL' -> 'varchar'."""
        if field_type and field_type != 'double precision':
            return field_type.split(" ")[0].split("(")[0].lower()
        return field_type

    def get_index_together(self, meta):
        """Return the attname tuples of multi-column indexes declared on the model.

        Combines ``Meta.index_together`` (removed in Django 5.1, hence the
        getattr) with ``Meta.indexes``. A leading '-' (descending order) is
        dropped from ``Index.fields`` entries, and expression-only indexes
        (no ``fields``) are skipped; they are still checked by name in
        find_index_missing_in_db.
        """
        indexes_normalized = list(normalize_together(getattr(meta, 'index_together', ())))

        for idx in meta.indexes:
            fields = [name[1:] if name.startswith('-') else name for name in idx.fields]
            if fields:
                indexes_normalized.append(fields)

        return self.expand_together(indexes_normalized, meta)

    def get_unique_together(self, meta):
        """Return attname tuples from ``unique_together`` and field-based UniqueConstraints.

        Expression-only UniqueConstraints (empty ``fields``) are skipped.
        """
        unique_normalized = list(normalize_together(meta.unique_together))

        for constraint in meta.constraints:
            if isinstance(constraint, UniqueConstraint) and constraint.fields:
                unique_normalized.append(constraint.fields)

        return self.expand_together(unique_normalized, meta)

    def expand_together(self, together, meta):
        """Map each group of field names to a tuple of their attnames."""
        return [
            tuple(meta.get_field(field).attname for field in fields)
            for fields in normalize_together(together)
        ]

    def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, table_name, skip_list=None):
        """Report UNIQUE constraints declared on the model but absent from the database.

        Covers ``unique=True`` fields (managed models only) and unique
        field groups from get_unique_together(). Entries in ``skip_list``
        (attnames or attname tuples) are ignored.
        """
        schema_editor = connection.SchemaEditorClass(connection)
        for field in all_local_fields(meta):
            if skip_list and field.attname in skip_list:
                continue
            if field.unique and meta.managed:
                attname = field.db_column or field.attname
                db_field_unique = table_indexes.get(attname, {}).get('unique')
                if not db_field_unique and table_constraints:
                    # table_indexes keeps only one constraint per column; check them all.
                    db_field_unique = any(
                        constraint['unique']
                        for constraint in table_constraints.values()
                        if [attname] == constraint['columns']
                    )
                if attname in table_indexes and db_field_unique:
                    continue

                index_name = schema_editor._create_index_name(table_name, [attname])

                self.add_difference('unique-missing-in-db', table_name, [attname], index_name + "_uniq")
                db_type = field.db_type(connection=connection)
                if db_type.startswith('varchar'):
                    self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' varchar_pattern_ops')
                if db_type.startswith('text'):
                    self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' text_pattern_ops')

        unique_together = self.get_unique_together(meta)
        db_unique_columns = normalize_together([v['columns'] for v in table_constraints.values() if v['unique'] and not v['index']])

        for unique_columns in unique_together:
            if unique_columns in db_unique_columns:
                continue

            if skip_list and unique_columns in skip_list:
                continue

            index_name = schema_editor._create_index_name(table_name, unique_columns)

            self.add_difference('unique-missing-in-db', table_name, unique_columns, index_name + "_uniq")

    def find_unique_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
        """Report database UNIQUE constraints (not unique indexes) the model does not declare."""
        fields = {field.column: field for field in all_local_fields(meta)}
        unique_together = self.get_unique_together(meta)

        for constraint_name, constraint in table_constraints.items():
            if not constraint['unique']:
                continue
            if constraint['index']:
                # unique indexes are handled by find_index_missing_in_model
                continue

            columns = constraint['columns']
            if len(columns) == 1:
                field = fields.get(columns[0])
                if field is not None and field.unique:
                    continue
            elif tuple(columns) in unique_together:
                continue

            self.add_difference('unique-missing-in-model', table_name, constraint_name)

    def find_index_missing_in_db(self, meta, table_indexes, table_constraints, table_name):
        """Report indexes declared on the model (db_index, index groups, Meta.indexes) but absent in the database."""
        schema_editor = connection.SchemaEditorClass(connection)
        for field in all_local_fields(meta):
            if field.db_index:
                attname = field.db_column or field.attname
                if attname not in table_indexes:
                    index_name = schema_editor._create_index_name(table_name, [attname])
                    self.add_difference('index-missing-in-db', table_name, [attname], index_name, '')
                    db_type = field.db_type(connection=connection)
                    if db_type.startswith('varchar'):
                        self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' varchar_pattern_ops')
                    if db_type.startswith('text'):
                        self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' text_pattern_ops')

        index_together = self.get_index_together(meta)
        db_index_together = normalize_together([v['columns'] for v in table_constraints.values() if v['index'] and not v['unique']])
        for columns in index_together:
            if columns in db_index_together:
                continue
            index_name = schema_editor._create_index_name(table_name, columns)
            self.add_difference('index-missing-in-db', table_name, columns, index_name + "_idx", '')

        for index in meta.indexes:
            if index.name not in table_constraints:
                self.add_difference('index-missing-in-db', table_name, index.fields, index.name, '')

    def find_index_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
        """Report database indexes/constraints that the model does not account for."""
        fields = {field.column: field for field in all_local_fields(meta)}
        meta_index_names = [idx.name for idx in meta.indexes]
        index_together = self.get_index_together(meta)

        for constraint_name, constraint in table_constraints.items():
            if constraint_name in meta_index_names:
                continue
            if constraint['unique'] and not constraint['index']:
                # unique constraints are handled by find_unique_missing_in_model
                continue

            columns = constraint['columns']
            # Expression-based indexes can be introspected with no columns at all.
            field = fields.get(columns[0]) if columns else None
            if (constraint['unique'] and constraint['index']) or field is None:
                # unique indexes do not exist in django - only unique constraints
                pass
            elif len(columns) == 1:
                if constraint['primary_key'] and field.primary_key:
                    continue
                if constraint['foreign_key'] and isinstance(field, models.ForeignKey) and field.db_constraint:
                    continue
                if constraint['unique'] and field.unique:
                    continue
                if constraint['index'] and constraint['type'] == 'idx' and constraint.get('orders') and field.unique:
                    # django automatically creates a _like varchar_pattern_ops/text_pattern_ops index see https://code.djangoproject.com/ticket/12234
                    # note: mysql does not have and/or introspect and fill the 'orders' attribute of constraint information
                    continue
                if constraint['index'] and field.db_index:
                    continue
                if constraint['check'] and field.db_check(connection=connection):
                    continue
                if getattr(field, 'spatial_index', False):
                    continue
            else:
                if constraint['index'] and tuple(columns) in index_together:
                    continue

            self.add_difference('index-missing-in-model', table_name, constraint_name)

    def find_field_missing_in_model(self, fieldmap, table_description, table_name):
        """Report database columns with no corresponding model field."""
        for row in table_description:
            if row[0] not in fieldmap:
                self.add_difference('field-missing-in-model', table_name, row[0])

    def find_field_missing_in_db(self, fieldmap, table_description, table_name):
        """Report model fields with no database column; relations become 'fkey-missing-in-db'."""
        db_fields = {row[0] for row in table_description}
        for field_name, field in fieldmap.items():
            if field_name in db_fields:
                continue
            field_output = []

            if field.remote_field:
                field_output.extend([field.remote_field.model._meta.db_table, field.remote_field.model._meta.get_field(field.remote_field.field_name).column])
                op = 'fkey-missing-in-db'
            else:
                op = 'field-missing-in-db'
            field_output.append(field.db_type(connection=connection))
            if self.options['include_defaults'] and field.has_default():
                field_output.append('DEFAULT %s' % field.get_prep_value(field.get_default()))
            if not field.null:
                field_output.append('NOT NULL')
            self.add_difference(op, table_name, field_name, *field_output)
            self.new_db_fields.add((table_name, field_name))

    def find_field_type_differ(self, meta, table_description, table_name, func=None):
        """Report fields whose base column type differs between model and database.

        Columns are matched by ``field.column`` (db_column or attname), so
        ForeignKeys and ``db_column`` fields are compared too. ``func``, if
        given, may rewrite ``(model_type, db_type)`` before the comparison.
        """
        db_fields = {row[0]: row for row in table_description}
        for field in all_local_fields(meta):
            description = db_fields.get(field.column)
            if description is None:
                continue

            model_type = self.get_field_model_type(field)
            db_type = self.get_field_db_type(description, field, table_name)

            # use callback function if defined
            if func:
                model_type, db_type = func(field, description, model_type, db_type)

            if self.strip_parameters(db_type) != self.strip_parameters(model_type):
                self.add_difference('field-type-differ', table_name, field.column, model_type, db_type)

    def find_field_parameter_differ(self, meta, table_description, table_name, func=None):
        """Report fields of the same base type whose parameters or CHECK differ.

        Columns are matched by ``field.column``, like find_field_type_differ.
        """
        db_fields = {row[0]: row for row in table_description}
        for field in all_local_fields(meta):
            description = db_fields.get(field.column)
            if description is None:
                continue

            model_type = self.get_field_model_type(field)
            db_type = self.get_field_db_type(description, field, table_name)

            # Base type differences are reported by find_field_type_differ.
            if self.strip_parameters(model_type) != self.strip_parameters(db_type):
                continue

            # use callback function if defined
            if func:
                model_type, db_type = func(field, description, model_type, db_type)

            model_check = field.db_parameters(connection=connection)['check']
            if ' CHECK' in db_type:
                db_type, db_check = db_type.split(" CHECK", 1)
                db_check = db_check.strip().lstrip("(").rstrip(")")
            else:
                db_check = None

            if model_type != db_type or model_check != db_check:
                self.add_difference('field-parameter-differ', table_name, field.column, model_type, db_type)

    def find_field_notnull_differ(self, meta, table_description, table_name):
        """Report NULL/NOT NULL mismatches (only if the backend can detect them)."""
        if not self.can_detect_notnull_differ:
            return

        for field in all_local_fields(meta):
            attname = field.db_column or field.attname
            # Columns missing in the database were already reported.
            if (table_name, attname) in self.new_db_fields:
                continue
            null = self.get_field_db_nullable(field, table_name)
            if field.null != null:
                action = 'DROP' if field.null else 'SET'
                self.add_difference('notnull-differ', table_name, attname, action)

    def get_constraints(self, cursor, table_name, introspection):
        """Fallback used when the introspection class lacks get_constraints()."""
        return {}

    def find_differences(self):
        """Compare every model in ``self.app_models`` with the database.

        Sets ``self.has_differences`` to the largest number of differences
        recorded for any entry (0 when nothing was checked).
        """
        if self.options['all_applications']:
            self.add_app_model_marker(None, None)
            for table in self.db_tables:
                if table not in self.django_tables and table not in self.IGNORE_MISSING_TABLES:
                    self.add_difference('table-missing-in-model', table)

        for app_model in self.app_models:
            meta = app_model._meta
            table_name = meta.db_table
            app_label = meta.app_label

            if not self.options['include_proxy_models'] and meta.proxy:
                continue

            # Marker indicating start of difference scan for this model
            self.add_app_model_marker(app_label, app_model.__name__)

            if table_name not in self.db_tables:
                # Table is missing from database
                self.add_difference('table-missing-in-db', table_name)
                continue

            if hasattr(self.introspection, 'get_constraints'):
                table_constraints = self.introspection.get_constraints(self.cursor, table_name)
            else:
                table_constraints = self.get_constraints(self.cursor, table_name, self.introspection)

            fieldmap = {field.db_column or field.get_attname(): field for field in all_local_fields(meta)}

            # add ordering field if model uses order_with_respect_to
            if meta.order_with_respect_to:
                fieldmap['_order'] = ORDERING_FIELD

            try:
                table_description = self.introspection.get_table_description(self.cursor, table_name)
            except Exception as e:
                self.add_difference('error', 'unable to introspect table: %s' % str(e).strip())
                transaction.rollback()  # reset transaction
                continue

            # map single-column table_constraints into table_indexes
            table_indexes = {}
            for constraint_name, dct in table_constraints.items():
                columns = dct['columns']
                if len(columns) == 1:
                    table_indexes[columns[0]] = {
                        'primary_key': dct['primary_key'],
                        'unique': dct['unique'],
                        'type': dct.get('type'),
                        'contraint_name': constraint_name,
                    }

            # Fields which are defined in database but not in model
            # 1) find: 'unique-missing-in-model'
            self.find_unique_missing_in_model(meta, table_indexes, table_constraints, table_name)
            # 2) find: 'index-missing-in-model'
            self.find_index_missing_in_model(meta, table_indexes, table_constraints, table_name)
            # 3) find: 'field-missing-in-model'
            self.find_field_missing_in_model(fieldmap, table_description, table_name)

            # Fields which are defined in models but not in database
            # 4) find: 'field-missing-in-db'
            self.find_field_missing_in_db(fieldmap, table_description, table_name)
            # 5) find: 'unique-missing-in-db'
            self.find_unique_missing_in_db(meta, table_indexes, table_constraints, table_name)
            # 6) find: 'index-missing-in-db'
            self.find_index_missing_in_db(meta, table_indexes, table_constraints, table_name)

            # Fields which have a different type or parameters
            # 7) find: 'type-differs'
            self.find_field_type_differ(meta, table_description, table_name)
            # 8) find: 'type-parameter-differs'
            self.find_field_parameter_differ(meta, table_description, table_name)
            # 9) find: 'field-notnull'
            self.find_field_notnull_differ(meta, table_description, table_name)

        # default=0 avoids ValueError when no entry was recorded (e.g. only proxies).
        self.has_differences = max((len(diffs) for _app_label, _model_name, diffs in self.differences), default=0)

    def print_diff(self, style=no_style()):
        """ Print differences to stdout """
        if self.options['sql']:
            self.print_diff_sql(style)
        else:
            self.print_diff_text(style)

    def print_diff_text(self, style):
        """Write the differences as human-readable text."""
        if not self.can_detect_notnull_differ:
            self.stdout.write(style.NOTICE("# Detecting notnull changes not implemented for this database backend"))
            self.stdout.write("")

        if not self.can_detect_unsigned_differ:
            self.stdout.write(style.NOTICE("# Detecting unsigned changes not implemented for this database backend"))
            self.stdout.write("")

        cur_app_label = None
        for app_label, model_name, diffs in self.differences:
            if not diffs:
                continue
            if not self.dense and app_label and cur_app_label != app_label:
                self.stdout.write("%s %s" % (style.NOTICE("+ Application:"), style.SQL_TABLE(app_label)))
                cur_app_label = app_label
            if not self.dense and model_name:
                self.stdout.write("%s %s" % (style.NOTICE("|-+ Differences for model:"), style.SQL_TABLE(model_name)))
            for diff_type, diff_args in diffs:
                text = self.DIFF_TEXTS[diff_type] % dict(
                    (str(i), style.SQL_TABLE(', '.join(e) if isinstance(e, (list, tuple)) else e))
                    for i, e in enumerate(diff_args)
                )
                # Highlight the parts outside single quotes (even split indexes).
                text = "'".join(e if i % 2 else style.ERROR(e) for i, e in enumerate(text.split("'")))
                if not self.dense:
                    self.stdout.write("%s %s" % (style.NOTICE("|--+"), text))
                elif app_label:
                    self.stdout.write("%s %s %s %s %s" % (style.NOTICE("App"), style.SQL_TABLE(app_label), style.NOTICE('Model'), style.SQL_TABLE(model_name), text))
                else:
                    self.stdout.write(text)

    def print_diff_sql(self, style):
        """Write the differences as (approximate) SQL wrapped in BEGIN/COMMIT."""
        if not self.can_detect_notnull_differ:
            self.stdout.write(style.NOTICE("-- Detecting notnull changes not implemented for this database backend"))
            self.stdout.write("")

        cur_app_label = None
        qn = connection.ops.quote_name
        if not self.has_differences:
            if not self.dense:
                self.stdout.write(style.SQL_KEYWORD("-- No differences"))
        else:
            self.stdout.write(style.SQL_KEYWORD("BEGIN;"))
            for app_label, model_name, diffs in self.differences:
                if not diffs:
                    continue
                # The all-applications marker has no app_label; don't print "None".
                if not self.dense and app_label and cur_app_label != app_label:
                    self.stdout.write(style.NOTICE("-- Application: %s" % style.SQL_TABLE(app_label)))
                    cur_app_label = app_label
                if not self.dense and model_name:
                    self.stdout.write(style.NOTICE("-- Model: %s" % style.SQL_TABLE(model_name)))
                for diff_type, diff_args in diffs:
                    text = self.DIFF_SQL[diff_type](style, qn, diff_args)
                    if self.dense:
                        text = text.replace("\n\t", " ")
                    self.stdout.write(text)
            self.stdout.write(style.SQL_KEYWORD("COMMIT;"))
|
|
|
|
|
|
class GenericSQLDiff(SQLDiff):
    """SQLDiff variant that detects neither NULL nor UNSIGNED differences.

    NOTE(review): presumably the fallback for database backends without a
    dedicated differ - confirm against the command's backend selection.
    """

    can_detect_notnull_differ = False
    can_detect_unsigned_differ = False

    def load_null(self):
        """Do nothing: nullability is not introspected by this differ."""
        return None

    def load_unsigned(self):
        """Do nothing: unsigned columns are not introspected by this differ."""
        return None
|
|
|
|
|
|
class MySQLDiff(SQLDiff):
|
|
can_detect_notnull_differ = True
|
|
can_detect_unsigned_differ = True
|
|
unsigned_suffix = 'UNSIGNED'
|
|
|
|
def load(self):
    """Run the generic introspection, then record AUTO_INCREMENT columns.

    ``self.auto_increment`` holds ``(table_name, column_name)`` pairs and is
    filled by load_auto_increment().
    """
    super().load()
    self.auto_increment = set()
    self.load_auto_increment()
|
|
|
|
def format_field_names(self, field_names):
    """Lower-case result column names so sql_to_dict rows use lower-case keys.

    NOTE(review): presumably needed because MySQL may report
    information_schema column names in upper case - confirm.
    """
    lowered = []
    for name in field_names:
        lowered.append(name.lower())
    return lowered
|
|
|
|
def load_null(self):
    """Record, for every column of every database table, whether it is nullable.

    Fills ``self.null`` with ``('public', table_name, column_name) -> bool``;
    the tablespace is always 'public', which is also the default used when
    looking the value up for a model field.
    """
    query = """
        SELECT column_name, is_nullable
        FROM information_schema.columns
        WHERE table_schema = DATABASE()
          AND table_name = %s"""
    for table_name in self.db_tables:
        for row in self.sql_to_dict(query, [table_name]):
            self.null[('public', table_name, row['column_name'])] = row['is_nullable'] == 'YES'
|
|
|
|
def load_unsigned(self):
|
|
tablespace = 'public'
|
|
for table_name in self.db_tables:
|
|
result = self.sql_to_dict("""
|
|
SELECT column_name
|
|
FROM information_schema.columns
|
|
WHERE table_schema = DATABASE()
|
|
AND table_name = %s
|
|
AND column_type LIKE '%%unsigned'""", [table_name])
|
|
for table_info in result:
|
|
key = (tablespace, table_name, table_info['column_name'])
|
|
self.unsigned.add(key)
|
|
|
|
def load_auto_increment(self):
|
|
for table_name in self.db_tables:
|
|
result = self.sql_to_dict("""
|
|
SELECT column_name
|
|
FROM information_schema.columns
|
|
WHERE table_schema = DATABASE()
|
|
AND table_name = %s
|
|
AND extra = 'auto_increment'""", [table_name])
|
|
for table_info in result:
|
|
key = (table_name, table_info['column_name'])
|
|
self.auto_increment.add(key)
|
|
|
|
# All the MySQL hacks together create something of a problem
|
|
# Fixing one bug in MySQL creates another issue. So just keep in mind
|
|
# that this is way unreliable for MySQL atm.
|
|
def get_field_db_type(self, description, field=None, table_name=None):
|
|
db_type = super().get_field_db_type(description, field, table_name)
|
|
if not db_type:
|
|
return
|
|
if field:
|
|
# MySQL isn't really sure about char's and varchar's like sqlite
|
|
field_type = self.get_field_model_type(field)
|
|
|
|
# Fix char/varchar inconsistencies
|
|
if self.strip_parameters(field_type) == 'char' and self.strip_parameters(db_type) == 'varchar':
|
|
db_type = db_type.lstrip("var")
|
|
|
|
# They like to call bools various integer types and introspection makes that a integer
|
|
# just convert them all to bools
|
|
if self.strip_parameters(field_type) == 'bool':
|
|
if db_type == 'integer':
|
|
db_type = 'bool'
|
|
|
|
if (table_name, field.column) in self.auto_increment and 'AUTO_INCREMENT' not in db_type:
|
|
db_type += ' AUTO_INCREMENT'
|
|
return db_type
|
|
|
|
def find_index_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
|
|
fields = dict([(field.column, field) for field in all_local_fields(meta)])
|
|
meta_index_names = [idx.name for idx in meta.indexes]
|
|
index_together = self.get_index_together(meta)
|
|
unique_together = self.get_unique_together(meta)
|
|
|
|
for constraint_name, constraint in table_constraints.items():
|
|
if constraint_name in meta_index_names:
|
|
continue
|
|
if constraint['unique'] and not constraint['index']:
|
|
# unique constraints are handled by find_unique_missing_in_model
|
|
continue
|
|
|
|
columns = constraint['columns']
|
|
field = fields.get(columns[0])
|
|
|
|
# extra check removed from superclass here, otherwise function is the same
|
|
if len(columns) == 1:
|
|
if not field:
|
|
# both index and field are missing from the model
|
|
self.add_difference('index-missing-in-model', table_name, constraint_name)
|
|
continue
|
|
if constraint['primary_key'] and field.primary_key:
|
|
continue
|
|
if constraint['foreign_key'] and isinstance(field, models.ForeignKey) and field.db_constraint:
|
|
continue
|
|
if constraint['unique'] and field.unique:
|
|
continue
|
|
if constraint['index'] and constraint['type'] == 'idx' and constraint.get('orders') and field.unique:
|
|
# django automatically creates a _like varchar_pattern_ops/text_pattern_ops index see https://code.djangoproject.com/ticket/12234
|
|
# note: mysql does not have and/or introspect and fill the 'orders' attribute of constraint information
|
|
continue
|
|
if constraint['index'] and field.db_index:
|
|
continue
|
|
if constraint['check'] and field.db_check(connection=connection):
|
|
continue
|
|
if getattr(field, 'spatial_index', False):
|
|
continue
|
|
else:
|
|
if constraint['index'] and tuple(columns) in index_together:
|
|
continue
|
|
if constraint['index'] and constraint['unique'] and tuple(columns) in unique_together:
|
|
continue
|
|
|
|
self.add_difference('index-missing-in-model', table_name, constraint_name)
|
|
|
|
def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, table_name, skip_list=None):
|
|
|
|
schema_editor = connection.SchemaEditorClass(connection)
|
|
for field in all_local_fields(meta):
|
|
if skip_list and field.attname in skip_list:
|
|
continue
|
|
if field.unique and meta.managed:
|
|
attname = field.db_column or field.attname
|
|
db_field_unique = table_indexes.get(attname, {}).get('unique')
|
|
if not db_field_unique and table_constraints:
|
|
db_field_unique = any(constraint['unique'] for contraint_name, constraint in table_constraints.items() if [attname] == constraint['columns'])
|
|
if attname in table_indexes and db_field_unique:
|
|
continue
|
|
|
|
index_name = schema_editor._create_index_name(table_name, [attname])
|
|
|
|
self.add_difference('unique-missing-in-db', table_name, [attname], index_name + "_uniq")
|
|
db_type = field.db_type(connection=connection)
|
|
if db_type.startswith('varchar'):
|
|
self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' varchar_pattern_ops')
|
|
if db_type.startswith('text'):
|
|
self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' text_pattern_ops')
|
|
|
|
unique_together = self.get_unique_together(meta)
|
|
|
|
# This comparison changed from superclass - otherwise function is the same
|
|
db_unique_columns = normalize_together([v['columns'] for v in table_constraints.values() if v['unique']])
|
|
|
|
for unique_columns in unique_together:
|
|
if unique_columns in db_unique_columns:
|
|
continue
|
|
|
|
if skip_list and unique_columns in skip_list:
|
|
continue
|
|
|
|
index_name = schema_editor._create_index_name(table_name, unique_columns)
|
|
self.add_difference('unique-missing-in-db', table_name, unique_columns, index_name + "_uniq")
|
|
|
|
|
|
class SqliteSQLDiff(SQLDiff):
    """SQLDiff implementation for SQLite.

    Nullability comes from ``PRAGMA table_info``; UNSIGNED does not exist in
    SQLite, and index detection is disabled because it does not work reliably.
    """

    can_detect_notnull_differ = True
    can_detect_unsigned_differ = False

    def load_null(self):
        """Fill ``self.null`` with a nullable flag for every column of every table."""
        # sqlite does not support tablespaces
        tablespace = "public"
        for table_name in self.db_tables:
            # The table name is embedded as a string literal in the PRAGMA;
            # double any single quote so such names do not break the statement.
            literal = table_name.replace("'", "''")
            # index, column_name, column_type, nullable, default_value
            # see: https://www.sqlite.org/pragma.html#pragma_table_info
            for table_info in self.sql_to_dict("PRAGMA table_info('%s');" % literal, []):
                key = (tablespace, table_name, table_info['name'])
                self.null[key] = not table_info['notnull']

    def load_unsigned(self):
        """Do nothing; SQLite has no unsigned column types."""
        pass

    # Unique does not seem to be implied on Sqlite for Primary_key's
    # if this is more generic among databases this might be usefull
    # to add to the superclass's find_unique_missing_in_db method
    def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, table_name, skip_list=None):
        """Skip columns/sets already unique (or primary key) in SQLite, then defer to the superclass."""
        # Work on a copy so a list supplied by the caller is never mutated.
        skip_list = list(skip_list) if skip_list else []

        # Map db column name -> field for every field the model marks unique.
        unique_fields = {
            field.db_column or field.attname: field
            for field in all_local_fields(meta)
            if field.unique
        }

        for constraint in table_constraints.values():
            columns = constraint['columns']
            if len(columns) != 1:
                continue
            column = columns[0]
            field = unique_fields.get(column)
            if field is not None and (constraint['unique'] or constraint['primary_key']):
                skip_list.append(column)
                # NOTE(review): the superclass's per-field skip test appears to
                # compare field.attname (as the MySQL copy of that method does);
                # also record the attname so fields with db_column are skipped.
                if field.attname != column:
                    skip_list.append(field.attname)

        unique_together = self.get_unique_together(meta)
        db_unique_columns = normalize_together([v['columns'] for v in table_constraints.values() if v['unique']])

        for unique_columns in unique_together:
            if unique_columns in db_unique_columns:
                skip_list.append(unique_columns)

        super().find_unique_missing_in_db(meta, table_indexes, table_constraints, table_name, skip_list=skip_list)

    # Finding Indexes by using the get_indexes dictionary doesn't seem to work
    # for sqlite.
    def find_index_missing_in_db(self, meta, table_indexes, table_constraints, table_name):
        """Disabled for SQLite (see comment above)."""
        pass

    def find_index_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
        """Disabled for SQLite (see comment above)."""
        pass

    def get_field_db_type(self, description, field=None, table_name=None):
        """Return the column's database type with SQLite's char/varchar mismatch normalised."""
        db_type = super().get_field_db_type(description, field, table_name)
        if not db_type:
            return None
        if field:
            field_type = self.get_field_model_type(field)
            # Fix char/varchar inconsistencies
            if self.strip_parameters(field_type) == 'char' and self.strip_parameters(db_type) == 'varchar':
                db_type = db_type.lstrip("var")
        return db_type
|
|
|
|
|
|
class PostgresqlSQLDiff(SQLDiff):
    """SQLDiff implementation for PostgreSQL.

    Nullability and CHECK constraints are read from the system catalogs
    (``pg_attribute``, ``pg_constraint``); array column types are resolved
    with ``format_type`` against ``pg_attribute``.
    """

    can_detect_notnull_differ = True
    can_detect_unsigned_differ = True

    DATA_TYPES_REVERSE_NAME = {
        'hstore': 'django.contrib.postgres.fields.HStoreField',
        'jsonb': 'django.contrib.postgres.fields.JSONField',
    }

    # Hopefully in the future we can add constraint checking and other more
    # advanced checks based on this database.
    SQL_LOAD_CONSTRAINTS = """
    SELECT nspname, relname, conname, attname, pg_get_constraintdef(pg_constraint.oid)
    FROM pg_constraint
    INNER JOIN pg_attribute ON pg_constraint.conrelid = pg_attribute.attrelid AND pg_attribute.attnum = any(pg_constraint.conkey)
    INNER JOIN pg_class ON conrelid=pg_class.oid
    INNER JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace
    ORDER BY CASE WHEN contype='f' THEN 0 ELSE 1 END,contype,nspname,relname,conname;
    """
    SQL_LOAD_NULL = """
    SELECT nspname, relname, attname, attnotnull
    FROM pg_attribute
    INNER JOIN pg_class ON attrelid=pg_class.oid
    INNER JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace;
    """

    SQL_FIELD_TYPE_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('ALTER'), style.SQL_FIELD(qn(args[1])), style.SQL_KEYWORD("TYPE"), style.SQL_COLTYPE(args[2]))
    SQL_FIELD_PARAMETER_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('ALTER'), style.SQL_FIELD(qn(args[1])), style.SQL_KEYWORD("TYPE"), style.SQL_COLTYPE(args[2]))
    SQL_NOTNULL_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('ALTER COLUMN'), style.SQL_FIELD(qn(args[1])), style.SQL_KEYWORD(args[2]), style.SQL_KEYWORD('NOT NULL'))

    def load(self):
        """Run the common loading, then collect CHECK constraints."""
        super().load()
        # (nspname, relname, attname) -> constraint row, CHECK constraints only.
        self.check_constraints = {}
        self.load_constraints()

    def load_null(self):
        """Fill ``self.null`` keyed by (schema, table, column) from ``pg_attribute.attnotnull``."""
        for dct in self.sql_to_dict(self.SQL_LOAD_NULL, []):
            key = (dct['nspname'], dct['relname'], dct['attname'])
            self.null[key] = not dct['attnotnull']

    def load_unsigned(self):
        # PostgreSQL does not support unsigned, so no columns are
        # unsigned. Nothing to do.
        pass

    def load_constraints(self):
        """Store every constraint whose definition contains CHECK in ``self.check_constraints``."""
        for dct in self.sql_to_dict(self.SQL_LOAD_CONSTRAINTS, []):
            key = (dct['nspname'], dct['relname'], dct['attname'])
            if 'CHECK' in dct['pg_get_constraintdef']:
                self.check_constraints[key] = dct

    def get_data_type_arrayfield(self, base_field):
        """Return the ArrayField descriptor whose base field is an instance of *base_field*."""
        return {
            'name': 'django.contrib.postgres.fields.ArrayField',
            'kwargs': {
                'base_field': self.get_field_class(base_field)(),
            },
        }

    def get_data_types_reverse_override(self):
        """Map PostgreSQL type OIDs to field classes (or factories) not covered by Django."""
        return {
            1042: 'CharField',
            1000: lambda: self.get_data_type_arrayfield(base_field='BooleanField'),
            1001: lambda: self.get_data_type_arrayfield(base_field='BinaryField'),
            1002: lambda: self.get_data_type_arrayfield(base_field='CharField'),
            1005: lambda: self.get_data_type_arrayfield(base_field='IntegerField'),
            1006: lambda: self.get_data_type_arrayfield(base_field='IntegerField'),
            1007: lambda: self.get_data_type_arrayfield(base_field='IntegerField'),
            1009: lambda: self.get_data_type_arrayfield(base_field='CharField'),
            1014: lambda: self.get_data_type_arrayfield(base_field='CharField'),
            1015: lambda: self.get_data_type_arrayfield(base_field='CharField'),
            1016: lambda: self.get_data_type_arrayfield(base_field='BigIntegerField'),
            1017: lambda: self.get_data_type_arrayfield(base_field='FloatField'),
            1021: lambda: self.get_data_type_arrayfield(base_field='FloatField'),
            1022: lambda: self.get_data_type_arrayfield(base_field='FloatField'),
            1115: lambda: self.get_data_type_arrayfield(base_field='DateTimeField'),
            1185: lambda: self.get_data_type_arrayfield(base_field='DateTimeField'),
            1231: lambda: self.get_data_type_arrayfield(base_field='DecimalField'),
            # {'name': 'django.contrib.postgres.fields.ArrayField', 'kwargs': {'base_field': 'IntegerField'}},
            1186: lambda: self.get_data_type_arrayfield(base_field='DurationField'),
            # 1186: 'django.db.models.fields.DurationField',
            3614: 'django.contrib.postgres.search.SearchVectorField',
            3802: 'django.contrib.postgres.fields.JSONField',
        }

    def get_constraints(self, cursor, table_name, introspection):
        """
        Find constraints for table

        Backport of django's introspection.get_constraints(...)
        """
        constraints = {}
        # Loop over the key table, collecting things as constraints
        # This will get PKs, FKs, and uniques, but not CHECK
        cursor.execute("""
            SELECT
                kc.constraint_name,
                kc.column_name,
                c.constraint_type,
                array(SELECT table_name::text || '.' || column_name::text FROM information_schema.constraint_column_usage WHERE constraint_name = kc.constraint_name)
            FROM information_schema.key_column_usage AS kc
            JOIN information_schema.table_constraints AS c ON
                kc.table_schema = c.table_schema AND
                kc.table_name = c.table_name AND
                kc.constraint_name = c.constraint_name
            WHERE
                kc.table_schema = %s AND
                kc.table_name = %s
        """, ["public", table_name])
        for constraint, column, kind, used_cols in cursor.fetchall():
            # If we're the first column, make the record
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": [],
                    "primary_key": kind.lower() == "primary key",
                    "unique": kind.lower() in ["primary key", "unique"],
                    "foreign_key": tuple(used_cols[0].split(".", 1)) if kind.lower() == "foreign key" else None,
                    "check": False,
                    "index": False,
                }
            # Record the details
            constraints[constraint]['columns'].append(column)
        # Now get CHECK constraint columns
        cursor.execute("""
            SELECT kc.constraint_name, kc.column_name
            FROM information_schema.constraint_column_usage AS kc
            JOIN information_schema.table_constraints AS c ON
                kc.table_schema = c.table_schema AND
                kc.table_name = c.table_name AND
                kc.constraint_name = c.constraint_name
            WHERE
                c.constraint_type = 'CHECK' AND
                kc.table_schema = %s AND
                kc.table_name = %s
        """, ["public", table_name])
        for constraint, column in cursor.fetchall():
            # If we're the first column, make the record
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": [],
                    "primary_key": False,
                    "unique": False,
                    "foreign_key": None,
                    "check": True,
                    "index": False,
                }
            # Record the details
            constraints[constraint]['columns'].append(column)
        # Now get indexes
        cursor.execute("""
            SELECT
                c2.relname,
                ARRAY(
                    SELECT (SELECT attname FROM pg_catalog.pg_attribute WHERE attnum = i AND attrelid = c.oid)
                    FROM unnest(idx.indkey) i
                ),
                idx.indisunique,
                idx.indisprimary
            FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
                pg_catalog.pg_index idx
            WHERE c.oid = idx.indrelid
                AND idx.indexrelid = c2.oid
                AND c.relname = %s
        """, [table_name])
        for index, columns, unique, primary in cursor.fetchall():
            if index not in constraints:
                constraints[index] = {
                    "columns": list(columns),
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": None,
                    "check": False,
                    "index": True,
                }
        return constraints

    # def get_field_db_type_kwargs(self, current_kwargs, description, field=None, table_name=None, reverse_type=None):
    #     kwargs = {}
    #     if field and 'base_field' in current_kwargs:
    #         # find
    #         attname = field.db_column or field.attname
    #         introspect_db_type = self.sql_to_dict(
    #             """SELECT attname, format_type(atttypid, atttypmod) AS type
    #             FROM pg_attribute
    #             WHERE attrelid = %s::regclass
    #             AND attname = %s
    #             AND attnum > 0
    #             AND NOT attisdropped
    #             ORDER BY attnum;
    #             """,
    #             (table_name, attname)
    #         )[0]['type']
    #         # TODO: this gives the concrete type that the database uses, why not use this
    #         # much earlier in the process to compare to whatever django spits out as
    #         # the database type ?
    #         max_length = re.search("character varying\((\d+)\)\[\]", introspect_db_type)
    #         if max_length:
    #             kwargs['max_length'] = max_length[1]
    #     return kwargs

    def get_field_db_type(self, description, field=None, table_name=None):
        """Return the column's database type, normalised for PostgreSQL.

        - Array columns (type ending in ``[]``) are resolved to the concrete
          type reported by ``format_type`` (``character varying`` is shortened
          to ``varchar``).
        - Auto primary keys report ``serial``/``bigserial`` instead of
          ``integer``/``bigint``.
        - A CHECK constraint on the column is appended to the type.

        Returns ``None`` when the superclass cannot determine a type.
        """
        db_type = super().get_field_db_type(description, field, table_name)
        if not db_type:
            return
        if field:
            if db_type.endswith("[]") and table_name:
                # TODO: This is a hack for array types. Ideally we either pass the correct
                # constraints for the type in `get_data_type_arrayfield` which instantiates
                # the array base_field or maybe even better restructure sqldiff entirely
                # to be based around the concrete type yielded by the code below. That gives
                # the complete type the database uses, why not use this much earlier in the
                # process to compare to whatever django spits out as the desired database type ?
                attname = field.db_column or field.attname
                rows = self.sql_to_dict(
                    """SELECT attname, format_type(atttypid, atttypmod) AS type
                    FROM pg_attribute
                    WHERE attrelid = %s::regclass
                    AND attname = %s
                    AND attnum > 0
                    AND NOT attisdropped
                    ORDER BY attnum;
                    """,
                    # regclass folds unquoted names to lower case; quote the
                    # identifier so mixed-case table names resolve correctly.
                    (connection.ops.quote_name(table_name), attname)
                )
                # No matching column: fall back to the generic type instead of
                # raising IndexError.
                if rows:
                    introspect_db_type = rows[0]['type']
                    if introspect_db_type.startswith("character varying"):
                        introspect_db_type = introspect_db_type.replace("character varying", "varchar")

                    return introspect_db_type

            if field.primary_key and isinstance(field, AutoField):
                if db_type == 'integer':
                    db_type = 'serial'
                elif db_type == 'bigint':
                    db_type = 'bigserial'
            if table_name:
                # NOTE(review): field.db_tablespace is used as the schema
                # (nspname) part of the lookup key, defaulting to "public".
                tablespace = field.db_tablespace
                if tablespace == "":
                    tablespace = "public"
                attname = field.db_column or field.attname
                check_constraint = self.check_constraints.get((tablespace, table_name, attname), {}).get('pg_get_constraintdef', None)
                if check_constraint:
                    check_constraint = check_constraint.replace("((", "(")
                    check_constraint = check_constraint.replace("))", ")")
                    check_constraint = '("'.join([')' in e and '" '.join(p.strip('"') for p in e.split(" ", 1)) or e for e in check_constraint.split("(")])
                    # TODO: might be more then one constraint in definition ?
                    db_type += ' ' + check_constraint
        return db_type

    def get_field_db_type_lookup(self, type_code):
        """Map an array type's element OID to a field path via DATA_TYPES_REVERSE_NAME, or None."""
        try:
            name = self.sql_to_dict("SELECT typname FROM pg_type WHERE typelem=%s;", [type_code])[0]['typname']
            return self.DATA_TYPES_REVERSE_NAME.get(name.strip('_'))
        except (IndexError, KeyError):
            pass

    """
    def find_field_type_differ(self, meta, table_description, table_name):
        def callback(field, description, model_type, db_type):
            if field.primary_key and db_type=='integer':
                db_type = 'serial'
            return model_type, db_type
        super().find_field_type_differ(meta, table_description, table_name, callback)
    """
|
|
|
|
|
|
# Maps the last component of a DATABASES ENGINE path (e.g. "postgresql" for
# "django.db.backends.postgresql") to the SQLDiff subclass handling it.
# Engines not listed here fall back to GenericSQLDiff.
DATABASE_SQLDIFF_CLASSES = {
    'postgis': PostgresqlSQLDiff,
    'postgresql_psycopg2': PostgresqlSQLDiff,
    'postgresql': PostgresqlSQLDiff,
    'mysql': MySQLDiff,
    'sqlite3': SqliteSQLDiff,
    # GeoDjango's SQLite backend (django.contrib.gis.db.backends.spatialite).
    'spatialite': SqliteSQLDiff,
    'oracle': GenericSQLDiff
}
|
|
|
|
|
|
class Command(BaseCommand):
    help = """Prints the (approximated) difference between models and fields in the database for the given app name(s).

It indicates how columns in the database are different from the sql that would
be generated by Django. This command is not a database migration tool. (Though
it can certainly help) It's purpose is to show the current differences as a way
to check/debug ur models compared to the real database tables and columns."""

    output_transaction = False

    def add_arguments(self, parser):
        super().add_arguments(parser)
        parser.add_argument('app_label', nargs='*')
        parser.add_argument(
            '--all-applications', '-a', action='store_true',
            default=False,
            dest='all_applications',
            help="Automaticly include all application from INSTALLED_APPS."
        )
        parser.add_argument(
            '--not-only-existing', '-e', action='store_false',
            default=True,
            dest='only_existing',
            help="Check all tables that exist in the database, not only tables that should exist based on models."
        )
        parser.add_argument(
            '--dense-output', '-d', action='store_true', dest='dense_output',
            default=False,
            help="Shows the output in dense format, normally output is spreaded over multiple lines."
        )
        parser.add_argument(
            '--output_text', '-t', action='store_false', dest='sql',
            default=True,
            help="Outputs the differences as descriptive text instead of SQL"
        )
        parser.add_argument(
            '--include-proxy-models', action='store_true', dest='include_proxy_models',
            default=False,
            help="Include proxy models in the graph"
        )
        parser.add_argument(
            '--include-defaults', action='store_true', dest='include_defaults',
            default=False,
            help="Include default values in SQL output (beta feature)"
        )
        parser.add_argument(
            '--migrate-for-tests', action='store_true', dest='migrate_for_tests',
            default=False,
            help=argparse.SUPPRESS
        )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Process exit status used by run_from_argv: 1 until a run finds no differences.
        self.exit_code = 1

    @staticmethod
    def _resolve_engine(settings):
        """Return the short backend name of the default database (e.g. 'postgresql').

        The ENGINE setting is a dotted path; only its last component is kept so
        it can be looked up in DATABASE_SQLDIFF_CLASSES. When the setting is
        missing or empty, the default connection's settings_dict is consulted.
        """
        if hasattr(settings, 'DATABASES'):
            engine = settings.DATABASES.get('default', {}).get('ENGINE')
        else:
            engine = getattr(settings, 'DATABASE_ENGINE', None)
        if not engine:
            engine = connection.settings_dict.get('ENGINE') or ''
        return engine.split('.')[-1]

    @signalcommand
    def handle(self, *args, **options):
        """Collect the requested models, run the backend-specific diff and print it.

        Raises:
            CommandError: dummy database backend, no app label given, unknown
                app label, or no models found.
        """
        from django.conf import settings

        app_labels = options['app_label']

        # Normalise the dotted ENGINE path *before* testing for the dummy
        # backend; comparing the full path against 'dummy' never matched.
        engine = self._resolve_engine(settings)
        if engine == 'dummy':
            # This must be the "dummy" database backend, which means the user
            # hasn't set DATABASE_ENGINE.
            raise CommandError("""Django doesn't know which syntax to use for your SQL statements,
because you haven't specified the DATABASE_ENGINE setting.
Edit your settings file and change DATABASE_ENGINE to something like 'postgresql' or 'mysql'.""")

        if options['all_applications']:
            app_models = apps.get_models(include_auto_created=True)
        else:
            if not app_labels:
                raise CommandError('Enter at least one appname.')

            if not isinstance(app_labels, (list, tuple, set)):
                app_labels = [app_labels]

            app_models = []
            for app_label in app_labels:
                try:
                    app_config = apps.get_app_config(app_label)
                except LookupError as e:
                    # Report unknown labels as a command error, not a traceback.
                    raise CommandError(str(e)) from e
                app_models.extend(app_config.get_models(include_auto_created=True))

        if not app_models:
            raise CommandError('Unable to execute sqldiff no models founds.')

        if options['migrate_for_tests']:
            from django.core.management import call_command
            # call_command sets option dests directly: the non-interactive
            # switch is interactive=False ("no_input=True" maps to interactive=True).
            call_command("migrate", *app_labels, interactive=False, run_syncdb=True)

        cls = DATABASE_SQLDIFF_CLASSES.get(engine, GenericSQLDiff)
        sqldiff_instance = cls(app_models, options, stdout=self.stdout, stderr=self.stderr)
        sqldiff_instance.load()
        sqldiff_instance.find_differences()
        if not sqldiff_instance.has_differences:
            self.exit_code = 0
        sqldiff_instance.print_diff(self.style)

    def execute(self, *args, **options):
        """Run the command, printing CommandErrors and exiting with status 2 unless --traceback is set."""
        try:
            super().execute(*args, **options)
        except CommandError as e:
            if options['traceback']:
                raise

            # self.stderr is not guaranteed to be set here
            stderr = getattr(self, 'stderr', None)
            if not stderr:
                stderr = OutputWrapper(sys.stderr, self.style.ERROR)
            stderr.write('%s: %s' % (e.__class__.__name__, e))
            sys.exit(2)

    def run_from_argv(self, argv):
        """Run from the command line and exit with ``self.exit_code`` (0 = no differences)."""
        super().run_from_argv(argv)
        sys.exit(self.exit_code)
|