Skip to content
Snippets Groups Projects
Verified Commit 2a1c127d authored by Timm Fitschen's avatar Timm Fitschen
Browse files

WIP:

parent dbaf2326
No related branches found
No related tags found
No related merge requests found
...@@ -19,6 +19,12 @@ import caosdb ...@@ -19,6 +19,12 @@ import caosdb
LOGGER = getLogger(__name__) LOGGER = getLogger(__name__)
class _DatabaseClient(BaseDatabaseClient):
def __init__(self, connection):
super().__init__(connection)
LOGGER.debug("initializing database client")
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
""" """
DatabaseWrapper for MongoDB using SQL replacements. DatabaseWrapper for MongoDB using SQL replacements.
...@@ -45,7 +51,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): ...@@ -45,7 +51,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'IPAddressField': caosdb.TEXT, 'IPAddressField': caosdb.TEXT,
'GenericIPAddressField': caosdb.TEXT, 'GenericIPAddressField': caosdb.TEXT,
'NullBooleanField': caosdb.BOOLEAN, 'NullBooleanField': caosdb.BOOLEAN,
'OneToOneField': caosdb.INTEGER, 'OneToOneField': caosdb.REFERENCE,
'PositiveIntegerField': caosdb.INTEGER, 'PositiveIntegerField': caosdb.INTEGER,
'PositiveSmallIntegerField': caosdb.INTEGER, 'PositiveSmallIntegerField': caosdb.INTEGER,
'SlugField': caosdb.TEXT, 'SlugField': caosdb.TEXT,
...@@ -56,7 +62,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): ...@@ -56,7 +62,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'GenericObjectIdField': 'objectId', 'GenericObjectIdField': 'objectId',
'ObjectIdField': 'objectId', 'ObjectIdField': 'objectId',
'EmbeddedField': 'object', 'EmbeddedField': 'object',
'ArrayField': 'array' 'ArrayField': caosdb.LIST,
} }
data_types_suffix = { data_types_suffix = {
...@@ -86,7 +92,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): ...@@ -86,7 +92,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
SchemaEditorClass = DatabaseSchemaEditor SchemaEditorClass = DatabaseSchemaEditor
Database = Database Database = Database
client_class = BaseDatabaseClient client_class = _DatabaseClient
creation_class = DatabaseCreation creation_class = DatabaseCreation
features_class = DatabaseFeatures features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection introspection_class = DatabaseIntrospection
...@@ -97,6 +103,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): ...@@ -97,6 +103,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.djaosdb_connection = None self.djaosdb_connection = None
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def configure(self, session, user):
LOGGER.debug("configuring caosdb for %s, %s", session, user)
def is_usable(self): def is_usable(self):
if self.connection is not None: if self.connection is not None:
return True return True
......
...@@ -3,7 +3,7 @@ from logging import getLogger ...@@ -3,7 +3,7 @@ from logging import getLogger
import caosdb import caosdb
from django.contrib import messages from django.contrib import messages
logger = getLogger(__name__) LOGGER = getLogger(__name__)
class Result: class Result:
...@@ -58,6 +58,11 @@ class UpdateManyResult(Result): ...@@ -58,6 +58,11 @@ class UpdateManyResult(Result):
self.modified_count = len(container) # todo self.modified_count = len(container) # todo
self.matched_count = len(container) self.matched_count = len(container)
class DeleteManyResult(Result):
def __init__(self, container):
super(DeleteManyResult, self).__init__()
self.deleted_count = len(container)
class InsertManyResult(Result): class InsertManyResult(Result):
def __init__(self, container): def __init__(self, container):
...@@ -106,7 +111,7 @@ class DefaultCaosDBClientDelegate: ...@@ -106,7 +111,7 @@ class DefaultCaosDBClientDelegate:
self._caosdb.configure_connection(**kwargs) self._caosdb.configure_connection(**kwargs)
def _get_filter_clause(self, fil): def _get_filter_clause(self, fil):
logger.debug("enter _get_filter_clause(%s)", fil) LOGGER.debug("enter _get_filter_clause(%s)", fil)
if not "type" in fil: if not "type" in fil:
raise NotImplementedError("_get_filter_clause(%s)", fil) raise NotImplementedError("_get_filter_clause(%s)", fil)
...@@ -159,10 +164,17 @@ class DefaultCaosDBClientDelegate: ...@@ -159,10 +164,17 @@ class DefaultCaosDBClientDelegate:
res2 = self._caosdb.execute_query("SELECT name FROM RECORDTYPE") res2 = self._caosdb.execute_query("SELECT name FROM RECORDTYPE")
return [e.name for e in res1 + res2 if e.name is not None] return [e.name for e in res1 + res2 if e.name is not None]
def create_record_type(self, name : str, properties : list): def create_record_type(self, name : str, properties : list, parents : list):
c = self._caosdb.Container() c = self._caosdb.Container()
rt = self._caosdb.RecordType(name) rt = self._caosdb.RecordType(name)
c.append(rt) c.append(rt)
for p in parents:
name = p["name"]
if name not in self.cached_record_types:
# no way of dealing with this situation yet.
# wait and let the server throw an error.
pass
rt.add_parent(name)
for p in properties: for p in properties:
name = p["name"] name = p["name"]
datatype = p["datatype"] datatype = p["datatype"]
...@@ -191,7 +203,7 @@ class DefaultCaosDBClientDelegate: ...@@ -191,7 +203,7 @@ class DefaultCaosDBClientDelegate:
if prop["name"] in property_names: if prop["name"] in property_names:
new_rec.add_property(**prop) new_rec.add_property(**prop)
else: else:
logger.warning("%s has not been stored due to changes to the schema.", prop["name"]) LOGGER.warning("%s has not been stored due to changes to the schema.", prop["name"])
c.append(new_rec) c.append(new_rec)
c.insert(unique=False) c.insert(unique=False)
return InsertManyResult(c) return InsertManyResult(c)
...@@ -209,7 +221,6 @@ class DefaultCaosDBClientDelegate: ...@@ -209,7 +221,6 @@ class DefaultCaosDBClientDelegate:
res.update() res.update()
return UpdateManyResult(res) return UpdateManyResult(res)
def add_foreign_key(self, record_type, *args, **kwargs): def add_foreign_key(self, record_type, *args, **kwargs):
c = self._caosdb.Container() c = self._caosdb.Container()
rt = self._caosdb.RecordType(record_type).retrieve() rt = self._caosdb.RecordType(record_type).retrieve()
...@@ -232,7 +243,7 @@ class DefaultCaosDBClientDelegate: ...@@ -232,7 +243,7 @@ class DefaultCaosDBClientDelegate:
query = self._generate_query(record_type, sort, projection, filter, query = self._generate_query(record_type, sort, projection, filter,
count) count)
logger.debug("execute_query(%s)", query) LOGGER.debug("execute_query(%s)", query)
res = self._caosdb.execute_query(query) res = self._caosdb.execute_query(query)
if count: if count:
return CountResult(res, count) return CountResult(res, count)
...@@ -251,12 +262,18 @@ class DefaultCaosDBClientDelegate: ...@@ -251,12 +262,18 @@ class DefaultCaosDBClientDelegate:
list(sort.keys())) + " FROM" list(sort.keys())) + " FROM"
return f'{prefix} RECORD "{record_type}"{filter_clause}' return f'{prefix} RECORD "{record_type}"{filter_clause}'
def delete_many(self, record_type, filter):
container = self._find(record_type, filter=filter)
if len(container) > 0:
container.delete()
return DeleteManyResult(container)
class CaosDBClient: class CaosDBClient:
def __init__(self, enforce_schema=True, def __init__(self, enforce_schema=True,
delegate=DefaultCaosDBClientDelegate, **kwargs): delegate=DefaultCaosDBClientDelegate, **kwargs):
logger.info("New CaosDBClient: %s", kwargs) LOGGER.info("New CaosDBClient: %s", kwargs)
self.enforce_schema = enforce_schema self.enforce_schema = enforce_schema
self._delegate = delegate(**kwargs) self._delegate = delegate(**kwargs)
...@@ -277,14 +294,14 @@ class CaosDBClient: ...@@ -277,14 +294,14 @@ class CaosDBClient:
------- -------
record_types : list of str record_types : list of str
""" """
logger.debug("list_record_type_names") LOGGER.debug("list_record_type_names")
return self._delegate.list_record_type_names() return self._delegate.list_record_type_names()
def close(self): def close(self):
"""close the connection to a CaosDB Server""" """close the connection to a CaosDB Server"""
logger.debug("close") LOGGER.debug("close")
def create_record_type(self, name : str, properties : list): def create_record_type(self, name : str, properties : list, parents: list):
"""create_record_type """create_record_type
Parameters Parameters
...@@ -299,8 +316,8 @@ class CaosDBClient: ...@@ -299,8 +316,8 @@ class CaosDBClient:
------- -------
None None
""" """
logger.debug("create_record_type(%s, %s)", name, properties) LOGGER.debug("create_record_type(%s, %s, %s)", name, properties, parents)
self._delegate.create_record_type(name, properties) self._delegate.create_record_type(name, properties, parents)
def create_index(self, *args, **kwargs): def create_index(self, *args, **kwargs):
"""create_index """create_index
...@@ -329,7 +346,7 @@ class CaosDBClient: ...@@ -329,7 +346,7 @@ class CaosDBClient:
Returns Returns
------- -------
""" """
logger.debug("add_foreign_key(%s, %s)", args, kwargs) LOGGER.debug("add_foreign_key(%s, %s)", args, kwargs)
return self._delegate.add_foreign_key(*args, **kwargs) return self._delegate.add_foreign_key(*args, **kwargs)
def update_one(self, *args, **kwargs): def update_one(self, *args, **kwargs):
...@@ -359,7 +376,7 @@ class CaosDBClient: ...@@ -359,7 +376,7 @@ class CaosDBClient:
Returns Returns
------- -------
""" """
logger.debug("update_many(%s, %s)", args, kwargs) LOGGER.debug("update_many(%s, %s)", args, kwargs)
return self._delegate.update_many(*args, **kwargs) return self._delegate.update_many(*args, **kwargs)
def drop_record_type(self, *args, **kwargs): def drop_record_type(self, *args, **kwargs):
...@@ -389,7 +406,7 @@ class CaosDBClient: ...@@ -389,7 +406,7 @@ class CaosDBClient:
Returns Returns
------- -------
""" """
logger.debug("aggregate(%s, %s)", args, kwargs) LOGGER.debug("aggregate(%s, %s)", args, kwargs)
return self._delegate.aggregate(*args, **kwargs) return self._delegate.aggregate(*args, **kwargs)
def insert_many(self, record_type : str, records : list): def insert_many(self, record_type : str, records : list):
...@@ -407,7 +424,7 @@ class CaosDBClient: ...@@ -407,7 +424,7 @@ class CaosDBClient:
------- -------
result : InsertManyResult result : InsertManyResult
""" """
logger.debug("insert_many(%s, %s)", record_type, records) LOGGER.debug("insert_many(%s, %s)", record_type, records)
return self._delegate.insert_many(record_type, records) return self._delegate.insert_many(record_type, records)
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
...@@ -452,8 +469,8 @@ class CaosDBClient: ...@@ -452,8 +469,8 @@ class CaosDBClient:
Returns Returns
------- -------
""" """
# TODO LOGGER.debug("delete_many(%s, %s)", args, kwargs)
raise Exception("NOT IMPLEMENTED") return self._delegate.delete_many(*args, **kwargs)
def drop_index(self, *args, **kwargs): def drop_index(self, *args, **kwargs):
"""drop_index """drop_index
...@@ -483,5 +500,5 @@ class CaosDBClient: ...@@ -483,5 +500,5 @@ class CaosDBClient:
------- -------
result : FindResult result : FindResult
""" """
logger.debug("find(%s, %s)", args, kwargs) LOGGER.debug("find(%s, %s)", args, kwargs)
return self._delegate.find(*args, **kwargs) return self._delegate.find(*args, **kwargs)
from django.db import connections
from logging import getLogger
import threading
LOGGER = getLogger(__name__)
class DjaosDBSessionMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
connections["caosdb"].configure(request.session, request.user)
response = self.get_response(request)
# close connection
connections["caosdb"].close()
return response
from django.db.models import __all__ as django_models from django.db.models import __all__ as django_models
from django.db.models import * from django.db.models import *
from .manager import DjaosdbManager
from .fields import ( from .fields import (
ArrayField, DjongoManager, ArrayField, DjongoManager,
...@@ -8,6 +9,7 @@ from .fields import ( ...@@ -8,6 +9,7 @@ from .fields import (
) )
__all__ = django_models + [ __all__ = django_models + [
'DjaosdbManager',
'DjongoManager', 'ArrayField', 'DjongoManager', 'ArrayField',
'EmbeddedField', 'ArrayReferenceField', 'ObjectIdField', 'EmbeddedField', 'ArrayReferenceField', 'ObjectIdField',
'GenericObjectIdField', 'JSONField' 'GenericObjectIdField', 'JSONField'
......
from logging import getLogger
from django.db.models import Manager
LOGGER = getLogger(__name__)
class DjaosdbManager(Manager):
def __init__(self):
super().__init__()
LOGGER.debug("################# init")
self._db = "caosdb"
@property
def db(self):
db = super().db()
LOGGER.debug("################################################## db = %s", db)
return db
def get_queryset(self):
qs = super().get_queryset()
LOGGER.debug("################################## dir = %s", dir(qs))
LOGGER.debug("################################## db = %s", qs.db)
return qs
def db_manager(self):
LOGGER.debug("################################## db_manager _db = %s", self._db)
return super().db_manager()
...@@ -13,7 +13,7 @@ from dataclasses import dataclass, field as dataclass_field ...@@ -13,7 +13,7 @@ from dataclasses import dataclass, field as dataclass_field
# from caosdb.cursor import Cursor as BasicCursor # from caosdb.cursor import Cursor as BasicCursor
# from caosdb.database import Database # from caosdb.database import Database
# from caosdb.errors import OperationFailure, CollectionInvalid # from caosdb.errors import OperationFailure, CollectionInvalid
from caosdb import CaosDBException from caosdb import CaosDBException, REFERENCE
from sqlparse import parse as sqlparse from sqlparse import parse as sqlparse
from sqlparse import tokens from sqlparse import tokens
from sqlparse.sql import (Identifier, Parenthesis, Where, Statement) from sqlparse.sql import (Identifier, Parenthesis, Where, Statement)
...@@ -734,15 +734,34 @@ class CreateQuery(DDLQuery): ...@@ -734,15 +734,34 @@ class CreateQuery(DDLQuery):
f' for column definition: {statement}') f' for column definition: {statement}')
# collect properties # collect properties
properties, parents = self._get_properties_and_parents_from_column_defs(
SQLColumnDef.sql2col_defs(tok.value))
self.connection.create_record_type(name=record_type,
properties=properties,
parents=parents)
logger.debug('Created record_type: {}'.format(record_type))
def _get_properties_and_parents_from_column_defs(self, cols):
properties = [] properties = []
for col in SQLColumnDef.sql2col_defs(tok.value): parents = []
for col in cols:
if isinstance(col, SQLColumnConstraint): if isinstance(col, SQLColumnConstraint):
print_warn('column CONSTRAINTS') print_warn('column CONSTRAINTS')
else: continue
field = col.name field = col.name
if field in ['id', "name", "description"]: if field in ['id', "name", "description", "unit"]:
# id, name, description, unit can be ignored as they already
# exist in any caosdb server instance.
continue continue
# if field.endswith("_ptr_id"):
# # and col.data_type == REFERENCE:
# # these are the parents.
# parents.append({"name": field})
properties.append({"name": field, properties.append({"name": field,
"datatype": col.data_type "datatype": col.data_type
}) })
...@@ -759,9 +778,8 @@ class CreateQuery(DDLQuery): ...@@ -759,9 +778,8 @@ class CreateQuery(DDLQuery):
if (SQLColumnDef.not_null in col.col_constraints or if (SQLColumnDef.not_null in col.col_constraints or
SQLColumnDef.null in col.col_constraints): SQLColumnDef.null in col.col_constraints):
print_warn('NULL, NOT NULL column validation check') print_warn('NULL, NOT NULL column validation check')
return properties, parents
self.connection.create_record_type(record_type, properties)
logger.debug('Created record_type: {}'.format(record_type))
def parse(self): def parse(self):
statement = SQLStatement(self.statement) statement = SQLStatement(self.statement)
...@@ -785,7 +803,7 @@ class DeleteQuery(DMLQuery): ...@@ -785,7 +803,7 @@ class DeleteQuery(DMLQuery):
def parse(self): def parse(self):
statement = SQLStatement(self.statement) statement = SQLStatement(self.statement)
self.kw = kw = {'filter': {}} self.kw = kw = {}
statement.skip(4) statement.skip(4)
sql_token = SQLToken.token2sql(statement.next(), self) sql_token = SQLToken.token2sql(statement.next(), self)
self.left_table = sql_token.table self.left_table = sql_token.table
...@@ -793,11 +811,11 @@ class DeleteQuery(DMLQuery): ...@@ -793,11 +811,11 @@ class DeleteQuery(DMLQuery):
tok = statement.next() tok = statement.next()
if isinstance(tok, Where): if isinstance(tok, Where):
where = WhereConverter(self, statement) where = WhereConverter(self, statement)
kw.update(where.to_mongo()) kw["filter"] = where.to_mongo()
def execute(self): def execute(self):
db_con = self.connection db_con = self.connection
self.result = db_con[self.left_table].delete_many(**self.kw) self.result = db_con.delete_many(self.left_table, **self.kw)
logger.debug('delete_many: {}'.format(self.result.deleted_count)) logger.debug('delete_many: {}'.format(self.result.deleted_count))
def count(self): def count(self):
......
...@@ -376,31 +376,32 @@ class SQLColumnDef: ...@@ -376,31 +376,32 @@ class SQLColumnDef:
data_type=data_type, data_type=data_type,
col_constraints=col_constraints) col_constraints=col_constraints)
@classmethod # UNUSED CODE?
def statement2col_defs(cls, token: Token): # @classmethod
from djaosdb.base import DatabaseWrapper # def statement2col_defs(cls, token: Token):
supported_data_types = set(DatabaseWrapper.data_types.values()) # from djaosdb.base import DatabaseWrapper
# supported_data_types = set(DatabaseWrapper.data_types.values())
defs = token.value.strip('()').split(',')
for col in defs: # defs = token.value.strip('()').split(',')
col = col.strip() # for col in defs:
name, other = col.split(' ', 1) # col = col.strip()
if name == 'CONSTRAINT': # name, other = col.split(' ', 1)
yield SQLColumnConstraint() # if name == 'CONSTRAINT':
else: # yield SQLColumnConstraint()
if col[0] != '"': # else:
raise SQLDecodeError('Column identifier not quoted') # if col[0] != '"':
name, other = col[1:].split('"', 1) # raise SQLDecodeError('Column identifier not quoted')
other = other.strip() # name, other = col[1:].split('"', 1)
# other = other.strip()
data_type, constraint_sql = other.split(' ', 1)
if data_type not in supported_data_types: # data_type, constraint_sql = other.split(' ', 1)
raise NotSupportedError(f'Data of type: {data_type}') # if data_type not in supported_data_types:
# raise NotSupportedError(f'Data of type: {data_type}')
col_constraints = set(SQLColumnDef._get_constraints(constraint_sql))
yield SQLColumnDef(name=name, # col_constraints = set(SQLColumnDef._get_constraints(constraint_sql))
data_type=data_type, # yield SQLColumnDef(name=name,
col_constraints=col_constraints) # data_type=data_type,
# col_constraints=col_constraints)
class SQLColumnConstraint(SQLColumnDef): class SQLColumnConstraint(SQLColumnDef):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment