From 2a1c127daeb47dffee4fcf36611e5abc9102bae1 Mon Sep 17 00:00:00 2001
From: Timm Fitschen <t.fitschen@indiscale.com>
Date: Fri, 9 Oct 2020 17:24:23 +0200
Subject: [PATCH] WIP:

---
 djaosdb/base.py                 | 16 +++++++--
 djaosdb/caosdb_client.py        | 55 ++++++++++++++++++----------
 djaosdb/middleware.py           | 18 ++++++++++
 djaosdb/models/__init__.py      |  2 ++
 djaosdb/models/manager.py       | 29 +++++++++++++++
 djaosdb/sql2mongo/query.py      | 64 +++++++++++++++++++++------------
 djaosdb/sql2mongo/sql_tokens.py | 51 +++++++++++++-------------
 7 files changed, 165 insertions(+), 70 deletions(-)
 create mode 100644 djaosdb/middleware.py
 create mode 100644 djaosdb/models/manager.py

diff --git a/djaosdb/base.py b/djaosdb/base.py
index 0284938..ebdcfeb 100644
--- a/djaosdb/base.py
+++ b/djaosdb/base.py
@@ -19,6 +19,12 @@ import caosdb
 
 LOGGER = getLogger(__name__)
 
+class _DatabaseClient(BaseDatabaseClient):
+
+    def __init__(self, connection):
+        super().__init__(connection)
+        LOGGER.debug("initializing database client")
+
 class DatabaseWrapper(BaseDatabaseWrapper):
     """
     DatabaseWrapper for MongoDB using SQL replacements.
@@ -45,7 +51,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         'IPAddressField': caosdb.TEXT,
         'GenericIPAddressField': caosdb.TEXT,
         'NullBooleanField': caosdb.BOOLEAN,
-        'OneToOneField': caosdb.INTEGER,
+        'OneToOneField': caosdb.REFERENCE,
         'PositiveIntegerField': caosdb.INTEGER,
         'PositiveSmallIntegerField': caosdb.INTEGER,
         'SlugField': caosdb.TEXT,
@@ -56,7 +62,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         'GenericObjectIdField': 'objectId',
         'ObjectIdField': 'objectId',
         'EmbeddedField': 'object',
-        'ArrayField': 'array'
+        'ArrayField': caosdb.LIST,
     }
 
     data_types_suffix = {
@@ -86,7 +92,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
     SchemaEditorClass = DatabaseSchemaEditor
     Database = Database
 
-    client_class = BaseDatabaseClient
+    client_class = _DatabaseClient
     creation_class = DatabaseCreation
     features_class = DatabaseFeatures
     introspection_class = DatabaseIntrospection
@@ -97,6 +103,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         self.djaosdb_connection = None
         super().__init__(*args, **kwargs)
 
+
+    def configure(self, session, user):
+        LOGGER.debug("configuring caosdb for %s, %s", session, user)
+
     def is_usable(self):
         if self.connection is not None:
             return True
diff --git a/djaosdb/caosdb_client.py b/djaosdb/caosdb_client.py
index 9c147ec..de25d53 100644
--- a/djaosdb/caosdb_client.py
+++ b/djaosdb/caosdb_client.py
@@ -3,7 +3,7 @@ from logging import getLogger
 import caosdb
 from django.contrib import messages
 
-logger = getLogger(__name__)
+LOGGER = getLogger(__name__)
 
 
 class Result:
@@ -58,6 +58,11 @@ class UpdateManyResult(Result):
         self.modified_count = len(container) # todo
         self.matched_count = len(container)
 
+class DeleteManyResult(Result):
+
+    def __init__(self, container):
+        super(DeleteManyResult, self).__init__()
+        self.deleted_count = len(container)
 
 class InsertManyResult(Result):
     def __init__(self, container):
@@ -106,7 +111,7 @@ class DefaultCaosDBClientDelegate:
         self._caosdb.configure_connection(**kwargs)
 
     def _get_filter_clause(self, fil):
-        logger.debug("enter _get_filter_clause(%s)", fil)
+        LOGGER.debug("enter _get_filter_clause(%s)", fil)
         if not "type" in fil:
             raise NotImplementedError("_get_filter_clause(%s)", fil)
 
@@ -159,10 +164,17 @@ class DefaultCaosDBClientDelegate:
         res2 = self._caosdb.execute_query("SELECT name FROM RECORDTYPE")
         return [e.name for e in res1 + res2 if e.name is not None]
 
-    def create_record_type(self, name : str, properties : list):
+    def create_record_type(self, name : str, properties : list, parents : list):
         c = self._caosdb.Container()
         rt = self._caosdb.RecordType(name)
         c.append(rt)
+        for p in parents:
+            name = p["name"]
+            if name not in self.cached_record_types:
+                # no way of dealing with this situation yet.
+                # wait and let the server throw an error.
+                pass
+            rt.add_parent(name)
         for p in properties:
             name = p["name"]
             datatype = p["datatype"]
@@ -191,7 +203,7 @@ class DefaultCaosDBClientDelegate:
                     if prop["name"] in property_names:
                         new_rec.add_property(**prop)
                     else:
-                        logger.warning("%s has not been stored due to changes to the schema.", prop["name"])
+                        LOGGER.warning("%s has not been stored due to changes to the schema.", prop["name"])
             c.append(new_rec)
         c.insert(unique=False)
         return InsertManyResult(c)
@@ -209,7 +221,6 @@ class DefaultCaosDBClientDelegate:
         res.update()
         return UpdateManyResult(res)
 
-
     def add_foreign_key(self, record_type, *args, **kwargs):
         c = self._caosdb.Container()
         rt = self._caosdb.RecordType(record_type).retrieve()
@@ -232,7 +243,7 @@ class DefaultCaosDBClientDelegate:
         query = self._generate_query(record_type, sort, projection, filter,
                                      count)
 
-        logger.debug("execute_query(%s)", query)
+        LOGGER.debug("execute_query(%s)", query)
         res = self._caosdb.execute_query(query)
         if count:
             return CountResult(res, count)
@@ -251,12 +262,18 @@ class DefaultCaosDBClientDelegate:
                                            list(sort.keys())) + " FROM"
         return f'{prefix} RECORD "{record_type}"{filter_clause}'
 
+    def delete_many(self, record_type, filter):
+        container = self._find(record_type, filter=filter)
+        if len(container) > 0:
+            container.delete()
+        return DeleteManyResult(container)
+
 
 
 class CaosDBClient:
     def __init__(self, enforce_schema=True,
                  delegate=DefaultCaosDBClientDelegate, **kwargs):
-        logger.info("New CaosDBClient: %s", kwargs)
+        LOGGER.info("New CaosDBClient: %s", kwargs)
         self.enforce_schema = enforce_schema
         self._delegate = delegate(**kwargs)
 
@@ -277,14 +294,14 @@ class CaosDBClient:
         -------
         record_types : list of str
         """
-        logger.debug("list_record_type_names")
+        LOGGER.debug("list_record_type_names")
         return self._delegate.list_record_type_names()
 
     def close(self):
         """close the connection to a CaosDB Server"""
-        logger.debug("close")
+        LOGGER.debug("close")
 
-    def create_record_type(self, name : str, properties : list):
+    def create_record_type(self, name : str, properties : list, parents: list):
         """create_record_type
 
         Parameters
@@ -299,8 +316,8 @@ class CaosDBClient:
         -------
         None
         """
-        logger.debug("create_record_type(%s, %s)", name, properties)
-        self._delegate.create_record_type(name, properties)
+        LOGGER.debug("create_record_type(%s, %s, %s)", name, properties, parents)
+        self._delegate.create_record_type(name, properties, parents)
 
     def create_index(self, *args, **kwargs):
         """create_index
@@ -329,7 +346,7 @@ class CaosDBClient:
         Returns
         -------
         """
-        logger.debug("add_foreign_key(%s, %s)", args, kwargs)
+        LOGGER.debug("add_foreign_key(%s, %s)", args, kwargs)
         return self._delegate.add_foreign_key(*args, **kwargs)
 
     def update_one(self, *args, **kwargs):
@@ -359,7 +376,7 @@ class CaosDBClient:
         Returns
         -------
         """
-        logger.debug("update_many(%s, %s)", args, kwargs)
+        LOGGER.debug("update_many(%s, %s)", args, kwargs)
         return self._delegate.update_many(*args, **kwargs)
 
     def drop_record_type(self, *args, **kwargs):
@@ -389,7 +406,7 @@ class CaosDBClient:
         Returns
         -------
         """
-        logger.debug("aggregate(%s, %s)", args, kwargs)
+        LOGGER.debug("aggregate(%s, %s)", args, kwargs)
         return self._delegate.aggregate(*args, **kwargs)
 
     def insert_many(self, record_type : str, records : list):
@@ -407,7 +424,7 @@ class CaosDBClient:
         -------
         result : InsertManyResult
         """
-        logger.debug("insert_many(%s, %s)", record_type, records)
+        LOGGER.debug("insert_many(%s, %s)", record_type, records)
         return self._delegate.insert_many(record_type, records)
 
     def update(self, *args, **kwargs):
@@ -452,8 +469,8 @@ class CaosDBClient:
         Returns
         -------
         """
-        # TODO
-        raise Exception("NOT IMPLEMENTED")
+        LOGGER.debug("delete_many(%s, %s)", args, kwargs)
+        return self._delegate.delete_many(*args, **kwargs)
 
     def drop_index(self, *args, **kwargs):
         """drop_index
@@ -483,5 +500,5 @@ class CaosDBClient:
         -------
         result : FindResult
         """
-        logger.debug("find(%s, %s)", args, kwargs)
+        LOGGER.debug("find(%s, %s)", args, kwargs)
         return self._delegate.find(*args, **kwargs)
diff --git a/djaosdb/middleware.py b/djaosdb/middleware.py
new file mode 100644
index 0000000..b5d33f2
--- /dev/null
+++ b/djaosdb/middleware.py
@@ -0,0 +1,18 @@
+from django.db import connections
+from logging import getLogger
+import threading
+
+LOGGER = getLogger(__name__)
+
+class DjaosDBSessionMiddleware:
+    def __init__(self, get_response):
+        self.get_response = get_response
+
+    def __call__(self, request):
+        connections["caosdb"].configure(request.session, request.user)
+        response = self.get_response(request)
+
+        # close connection
+        connections["caosdb"].close()
+
+        return response
diff --git a/djaosdb/models/__init__.py b/djaosdb/models/__init__.py
index b939929..069a3da 100644
--- a/djaosdb/models/__init__.py
+++ b/djaosdb/models/__init__.py
@@ -1,5 +1,6 @@
 from django.db.models import __all__ as django_models
 from django.db.models import *
+from .manager import DjaosdbManager
 
 from .fields import (
     ArrayField, DjongoManager,
@@ -8,6 +9,7 @@ from .fields import (
 )
 
 __all__ = django_models + [
+    'DjaosdbManager',
     'DjongoManager', 'ArrayField',
     'EmbeddedField', 'ArrayReferenceField', 'ObjectIdField',
     'GenericObjectIdField', 'JSONField'
diff --git a/djaosdb/models/manager.py b/djaosdb/models/manager.py
new file mode 100644
index 0000000..8912a33
--- /dev/null
+++ b/djaosdb/models/manager.py
@@ -0,0 +1,29 @@
+from logging import getLogger
+from django.db.models import Manager
+
+LOGGER = getLogger(__name__)
+
+class DjaosdbManager(Manager):
+
+    def __init__(self):
+        super().__init__()
+        LOGGER.debug("################# init")
+        self._db = "caosdb"
+
+    @property
+    def db(self):
+        db = super().db()
+        LOGGER.debug("################################################## db = %s", db)
+        return db
+
+    def get_queryset(self):
+        qs = super().get_queryset()
+        LOGGER.debug("################################## dir = %s", dir(qs))
+        LOGGER.debug("################################## db = %s", qs.db)
+        return qs
+
+    def db_manager(self):
+        LOGGER.debug("################################## db_manager _db = %s", self._db)
+        return super().db_manager()
+
+
diff --git a/djaosdb/sql2mongo/query.py b/djaosdb/sql2mongo/query.py
index 75220cc..93ee7aa 100644
--- a/djaosdb/sql2mongo/query.py
+++ b/djaosdb/sql2mongo/query.py
@@ -13,7 +13,7 @@ from dataclasses import dataclass, field as dataclass_field
 # from caosdb.cursor import Cursor as BasicCursor
 # from caosdb.database import Database
 # from caosdb.errors import OperationFailure, CollectionInvalid
-from caosdb import CaosDBException
+from caosdb import CaosDBException, REFERENCE
 from sqlparse import parse as sqlparse
 from sqlparse import tokens
 from sqlparse.sql import (Identifier, Parenthesis, Where, Statement)
@@ -734,34 +734,52 @@ class CreateQuery(DDLQuery):
                                  f' for column definition: {statement}')
 
         # collect properties
+        properties, parents = self._get_properties_and_parents_from_column_defs(
+            SQLColumnDef.sql2col_defs(tok.value))
+
+        self.connection.create_record_type(name=record_type,
+                                           properties=properties,
+                                           parents=parents)
+        logger.debug('Created record_type: {}'.format(record_type))
+
+    def _get_properties_and_parents_from_column_defs(self, cols):
         properties = []
-        for col in SQLColumnDef.sql2col_defs(tok.value):
+        parents = []
+        for col in cols:
             if isinstance(col, SQLColumnConstraint):
                 print_warn('column CONSTRAINTS')
-            else:
-                field = col.name
-                if field in ['id', "name", "description"]:
-                    continue
+                continue
 
-                properties.append({"name": field,
-                                   "datatype": col.data_type
-                                  })
+            field = col.name
+            if field in ['id', "name", "description", "unit"]:
+                # id, name, description, unit can be ignored as they already
+                # exist in any caosdb server instance.
+                continue
 
-                if SQLColumnDef.autoincrement in col.col_constraints:
-                    print_warn("AUTO INCREMENT")
+            # if field.endswith("_ptr_id"):
+                # # and col.data_type == REFERENCE:
+                # # these are the parents.
+                # parents.append({"name": field})
 
-                if SQLColumnDef.primarykey in col.col_constraints:
-                    print_warn("PRIMARY KEY other than id")
 
-                if SQLColumnDef.unique in col.col_constraints:
-                    print_warn("UNIQUE INDEX")
+            properties.append({"name": field,
+                               "datatype": col.data_type
+                              })
 
-                if (SQLColumnDef.not_null in col.col_constraints or
-                        SQLColumnDef.null in col.col_constraints):
-                    print_warn('NULL, NOT NULL column validation check')
+            if SQLColumnDef.autoincrement in col.col_constraints:
+                print_warn("AUTO INCREMENT")
+
+            if SQLColumnDef.primarykey in col.col_constraints:
+                print_warn("PRIMARY KEY other than id")
+
+            if SQLColumnDef.unique in col.col_constraints:
+                print_warn("UNIQUE INDEX")
+
+            if (SQLColumnDef.not_null in col.col_constraints or
+                    SQLColumnDef.null in col.col_constraints):
+                print_warn('NULL, NOT NULL column validation check')
+        return properties, parents
 
-        self.connection.create_record_type(record_type, properties)
-        logger.debug('Created record_type: {}'.format(record_type))
 
     def parse(self):
         statement = SQLStatement(self.statement)
@@ -785,7 +803,7 @@ class DeleteQuery(DMLQuery):
 
     def parse(self):
         statement = SQLStatement(self.statement)
-        self.kw = kw = {'filter': {}}
+        self.kw = kw = {}
         statement.skip(4)
         sql_token = SQLToken.token2sql(statement.next(), self)
         self.left_table = sql_token.table
@@ -793,11 +811,11 @@ class DeleteQuery(DMLQuery):
         tok = statement.next()
         if isinstance(tok, Where):
             where = WhereConverter(self, statement)
-            kw.update(where.to_mongo())
+            kw["filter"] = where.to_mongo()
 
     def execute(self):
         db_con = self.connection
-        self.result = db_con[self.left_table].delete_many(**self.kw)
+        self.result = db_con.delete_many(self.left_table, **self.kw)
         logger.debug('delete_many: {}'.format(self.result.deleted_count))
 
     def count(self):
diff --git a/djaosdb/sql2mongo/sql_tokens.py b/djaosdb/sql2mongo/sql_tokens.py
index 7fb8e54..284a005 100644
--- a/djaosdb/sql2mongo/sql_tokens.py
+++ b/djaosdb/sql2mongo/sql_tokens.py
@@ -376,31 +376,32 @@ class SQLColumnDef:
                             data_type=data_type,
                             col_constraints=col_constraints)
 
-    @classmethod
-    def statement2col_defs(cls, token: Token):
-        from djaosdb.base import DatabaseWrapper
-        supported_data_types = set(DatabaseWrapper.data_types.values())
-
-        defs = token.value.strip('()').split(',')
-        for col in defs:
-            col = col.strip()
-            name, other = col.split(' ', 1)
-            if name == 'CONSTRAINT':
-                yield SQLColumnConstraint()
-            else:
-                if col[0] != '"':
-                    raise SQLDecodeError('Column identifier not quoted')
-                name, other = col[1:].split('"', 1)
-                other = other.strip()
-
-                data_type, constraint_sql = other.split(' ', 1)
-                if data_type not in supported_data_types:
-                    raise NotSupportedError(f'Data of type: {data_type}')
-
-                col_constraints = set(SQLColumnDef._get_constraints(constraint_sql))
-                yield SQLColumnDef(name=name,
-                                   data_type=data_type,
-                                   col_constraints=col_constraints)
+    # UNUSED CODE?
+    # @classmethod
+    # def statement2col_defs(cls, token: Token):
+        # from djaosdb.base import DatabaseWrapper
+        # supported_data_types = set(DatabaseWrapper.data_types.values())
+
+        # defs = token.value.strip('()').split(',')
+        # for col in defs:
+            # col = col.strip()
+            # name, other = col.split(' ', 1)
+            # if name == 'CONSTRAINT':
+                # yield SQLColumnConstraint()
+            # else:
+                # if col[0] != '"':
+                    # raise SQLDecodeError('Column identifier not quoted')
+                # name, other = col[1:].split('"', 1)
+                # other = other.strip()
+
+                # data_type, constraint_sql = other.split(' ', 1)
+                # if data_type not in supported_data_types:
+                    # raise NotSupportedError(f'Data of type: {data_type}')
+
+                # col_constraints = set(SQLColumnDef._get_constraints(constraint_sql))
+                # yield SQLColumnDef(name=name,
+                                   # data_type=data_type,
+                                   # col_constraints=col_constraints)
 
 
 class SQLColumnConstraint(SQLColumnDef):
-- 
GitLab