From d76d56cee35a17f37e9db61a52f06ff1b8ec9feb Mon Sep 17 00:00:00 2001
From: Timm Fitschen <t.fitschen@indiscale.com>
Date: Mon, 5 Oct 2020 11:38:29 +0200
Subject: [PATCH] WIP: djaosdb

---
 djaosdb/base.py                 |   2 +-
 djaosdb/caosdb_client.py        | 109 ++++++++++++++++++++++++++++----
 djaosdb/sql2mongo/converters.py |  53 +++++++++-------
 djaosdb/sql2mongo/functions.py  |   2 +-
 djaosdb/sql2mongo/query.py      |  86 ++++++++++++++++---------
 5 files changed, 186 insertions(+), 66 deletions(-)

diff --git a/djaosdb/base.py b/djaosdb/base.py
index ea38a04..0284938 100644
--- a/djaosdb/base.py
+++ b/djaosdb/base.py
@@ -46,7 +46,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         'GenericIPAddressField': caosdb.TEXT,
         'NullBooleanField': caosdb.BOOLEAN,
         'OneToOneField': caosdb.INTEGER,
-        'PositiveIntegerField': 'long',
+        'PositiveIntegerField': caosdb.INTEGER,
         'PositiveSmallIntegerField': caosdb.INTEGER,
         'SlugField': caosdb.TEXT,
         'SmallIntegerField': caosdb.INTEGER,
diff --git a/djaosdb/caosdb_client.py b/djaosdb/caosdb_client.py
index 94d033f..d323ba7 100644
--- a/djaosdb/caosdb_client.py
+++ b/djaosdb/caosdb_client.py
@@ -1,6 +1,7 @@
-import caosdb
+import abc
 from logging import getLogger
-from .mockup_delegator import MockupDelegatorConnection
+import caosdb
+from django.contrib import messages
 
 logger = getLogger(__name__)
 
@@ -23,18 +24,41 @@ class Result:
         # TODO
         return self
 
+class CountResult(Result):
+    """Single-row result of a COUNT query: ``{column_name: count}``."""
+    def __init__(self, count, column_name):
+        super().__init__()
+        self._results = [{column_name: count}]
+
+    def __iter__(self):
+        return iter(self._results)
+
 class FindResult(Result):
 
-    def __init__(self, rows, columns):
-        self.alive = True
+    def __init__(self, rows, columns, sort=None, limit=None):
+        super(FindResult, self).__init__()
         self._results = []
+        self._sort = sort
 
         for row in rows:
             self._results.append(dict(zip(columns, row)))
 
+        # TODO: order self._results by self._sort (not implemented yet)
+
+        if limit is not None:
+            self._results = self._results[0:limit]
+
     def __iter__(self):
         return iter(self._results)
 
+class UpdateManyResult(Result):
+    """Outcome of ``update_many``; both counters equal ``len(container)``."""
+    def __init__(self, container):
+        super().__init__()
+        self.matched_count = len(container)
+        self.modified_count = self.matched_count  # todo: count real changes
+
+
 class InsertManyResult(Result):
     def __init__(self, container):
         super(InsertManyResult, self).__init__()
@@ -77,9 +101,19 @@ class DefaultCaosDBClientDelegate:
         self._connection = caosdb.configure_connection(**kwargs)
 
     def _get_filter_clause(self, fil):
+        logger.debug("enter _get_filter_clause(%s)", fil)
         if "$and" in fil:
             components = [self._get_filter_clause(comps) for comps in fil["$and"]]
             return " AND".join(components)
+        if "reference" in fil:
+            ref = fil["reference"]
+            result = ref["str"]
+            if "sub" in ref:
+                if not "$in" in ref["sub"]:
+                    raise NotImplementedError("_get_filter_clause(%s)", fil)
+                components = [" ID = " + str(val) for val in ref["sub"]["$in"]]
+                result += " WITH (" + " OR".join(components) + " )"
+            return result
         if "p" in fil:
             n = ""
             if fil["negation"] is True:
@@ -90,15 +124,17 @@ class DefaultCaosDBClientDelegate:
             return f' {n}{p}{o}"{v}"'
         raise NotImplementedError("_get_filter_clause(%s)", fil)
 
-    def find(self, record_type, *args, **kwargs):
+    def _find(self, record_type, *args, **kwargs):
         filter_clause = ""
         if "filter" in kwargs:
             fil = kwargs["filter"]
-            filter_clause = "WITH " + self._get_filter_clause(fil)
+            filter_clause = " WITH " + self._get_filter_clause(fil)
 
         query = f'FIND RECORD "{record_type}"{filter_clause}'
-        res = caosdb.execute_query(query)
+        return caosdb.execute_query(query)
 
+    def find(self, record_type, *args, **kwargs):
+        res = self._find(record_type, *args, **kwargs)
         projection = kwargs["projection"]
         rows = res.get_property_values(*projection)
         return FindResult(rows, projection)
@@ -108,8 +144,9 @@ class DefaultCaosDBClientDelegate:
         return [e.name for e in res if e.name is not None]
 
     def list_property_names(self):
-        res = caosdb.execute_query("SELECT name FROM PROPERTY")
-        return [e.name for e in res if e.name is not None]
+        res1 = caosdb.execute_query("SELECT name FROM PROPERTY")
+        res2 = caosdb.execute_query("SELECT name FROM RECORDTYPE")
+        return [e.name for e in res1 + res2 if e.name is not None]
 
     def create_record_type(self, name : str, properties : list):
         c = caosdb.Container()
@@ -128,6 +165,7 @@ class DefaultCaosDBClientDelegate:
 
     def insert_many(self, record_type : str, records : list):
         c = caosdb.Container()
+        property_names = self.list_property_names()
         for rec in records:
             name, description = None, None
             if "name" in rec:
@@ -139,11 +177,28 @@ class DefaultCaosDBClientDelegate:
             new_rec.add_parent(record_type)
             if "properties" in rec:
                 for prop in rec["properties"]:
-                    new_rec.add_property(**prop)
+                    if prop["name"] in property_names:
+                        new_rec.add_property(**prop)
+                    else:
+                        logger.warning("%s has not been stored due to changes to the schema.", prop["name"])
             c.append(new_rec)
         c.insert(unique=False)
         return InsertManyResult(c)
 
+    def update_many(self, record_type: str, filter, update):
+        """Set ``update["$set"]`` values on all records matching ``filter``."""
+        res = self._find(record_type, filter=filter)
+        update_set = update["$set"]
+        for entity in res:
+            for name, val in update_set.items():
+                if hasattr(val, "isoformat"):
+                    val = val.isoformat()
+                if entity.get_property(name) is None:
+                    entity.add_property(name=name)  # $set adds missing fields
+                entity.get_property(name).value = val
+        res.update()
+        return UpdateManyResult(res)
+
     def add_foreign_key(self, record_type, *args, **kwargs):
         c = caosdb.Container()
         rt = caosdb.RecordType(record_type).retrieve()
@@ -158,6 +213,36 @@ class DefaultCaosDBClientDelegate:
         rt.get_property(kwargs["property"][0][0]).datatype=referenced
         c.update()
 
+    def aggregate(self, record_type, *args, **kwargs):
+        """Merge the spec in ``args[0]`` into kwargs and run ``_aggregate``."""
+        return self._aggregate(record_type, **{**kwargs, **dict(args[0])})
+
+    def _aggregate(self, record_type, reference=None, sort=None,
+                   projection=None, limit=None, filter=None, count=False):
+        """Run a FIND/SELECT/COUNT query; ``count`` is the count column name."""
+        projection = list(projection or ())  # tolerate the None default
+        clauses = []
+        if reference is not None:
+            fil = {"reference": reference}
+            clauses.append(" WHICH " + self._get_filter_clause(fil))
+        if filter is not None:
+            clauses.append(" WITH " + self._get_filter_clause(filter))
+        filter_clause = " AND".join(clauses)
+        prefix = "FIND"
+        if count:
+            prefix = "COUNT"
+        elif sort is not None:
+            prefix = "SELECT " + ", ".join(projection + list(sort)) + " FROM"
+        query = f'{prefix} RECORD "{record_type}"{filter_clause}'
+
+        logger.debug("execute_query(%s)", query)
+        res = caosdb.execute_query(query)
+        if count:
+            return CountResult(res, count)
+        rows = res.get_property_values(*projection)
+        return FindResult(rows, projection, sort, limit)
+
+
 
 class CaosDBClient:
     def __init__(self, enforce_schema=True,
@@ -265,8 +350,8 @@ class CaosDBClient:
         Returns
         -------
         """
-        # TODO
-        raise Exception("NOT IMPLEMENTED")
+        logger.debug("update_many(%s, %s)", args, kwargs)
+        return self._delegate.update_many(*args, **kwargs)
 
     def drop_record_type(self, *args, **kwargs):
         """drop_record_type
diff --git a/djaosdb/sql2mongo/converters.py b/djaosdb/sql2mongo/converters.py
index d87f0f1..f5c711c 100644
--- a/djaosdb/sql2mongo/converters.py
+++ b/djaosdb/sql2mongo/converters.py
@@ -61,20 +61,21 @@ class ColumnSelectConverter(Converter):
 class AggColumnSelectConverter(ColumnSelectConverter):
 
     def to_mongo(self):
-        project = {}
 
+        if isinstance(self.sql_tokens[0], CountFuncAll):
+            return self.sql_tokens[0].to_mongo()
         if any(isinstance(tok, SQLFunc) for tok in self.sql_tokens):
             # A SELECT func without groupby clause still needs a groupby
             # in MongoDB
             return self._using_group_by()
 
         elif isinstance(self.sql_tokens[0], SQLConstIdentifier):
+            project = {}
             project[self.sql_tokens[0].alias] = self.sql_tokens[0].to_mongo()
         else:
-            for selected in self.sql_tokens:
-                project[selected.field] = True
+            project = [selected.field for selected in self.sql_tokens]
 
-        return [{'$project': project}]
+        return {'projection': project}
 
     def _using_group_by(self):
         group = {
@@ -126,12 +127,6 @@ class WhereConverter(Converter):
         return {'filter': self.op.to_mongo()}
 
 
-class AggWhereConverter(WhereConverter):
-
-    def to_mongo(self):
-        return {'$match': self.op.to_mongo()}
-
-
 class JoinConverter(Converter):
 
     @abc.abstractmethod
@@ -165,6 +160,25 @@ class JoinConverter(Converter):
             self.left_column = sql.right_column
             self.right_column = sql.left_column
 
+    def _reference(self):
+        """Return a (back-)reference filter dict, or None if no ``id`` join."""
+        if self.left_column == "id":
+            key = "back_reference"
+            prop = self.right_column
+            text = f"IS REFERENCED BY {self.right_table} AS {prop}"
+        elif self.right_column == "id":
+            key = "reference"
+            prop = self.left_column
+            text = f"REFERENCES {self.right_table} AS {prop}"
+        else:
+            return None
+        return {key: {
+            "p": prop,
+            "v": self.right_table,
+            "negation": False,
+            "str": text,
+        }}
+
     def _lookup(self):
         if self.left_table == self.query.left_table:
             local_field = self.left_column
@@ -194,6 +208,11 @@ class InnerJoinConverter(JoinConverter):
         else:
             match_field = f'{self.left_table}.{self.left_column}'
 
+        # try to construct a references/referenced-by filter
+        reference = self._reference()
+        if reference is not None:
+            return reference
+
         lookup = self._lookup()
         pipeline = [
             {
@@ -263,12 +282,6 @@ class LimitConverter(Converter):
         return {'limit': self.limit}
 
 
-class AggLimitConverter(LimitConverter):
-
-    def to_mongo(self):
-        return {'$limit': self.limit}
-
-
 class OrderConverter(Converter):
     def __init__(self, *args):
         self.columns: List[SQLIdentifier] = []
@@ -315,7 +328,7 @@ class AggOrderConverter(OrderConverter):
         for tok in self.columns:
             sort[tok.field] = tok.order
 
-        return {'$sort': sort}
+        return {'sort': sort}
 
 
 class _Tokens2Id:
@@ -484,9 +497,3 @@ class OffsetConverter(Converter):
 
     def to_mongo(self):
         return {'skip': self.offset}
-
-
-class AggOffsetConverter(OffsetConverter):
-
-    def to_mongo(self):
-        return {'$skip': self.offset}
diff --git a/djaosdb/sql2mongo/functions.py b/djaosdb/sql2mongo/functions.py
index 76c551e..fe1c02c 100644
--- a/djaosdb/sql2mongo/functions.py
+++ b/djaosdb/sql2mongo/functions.py
@@ -92,7 +92,7 @@ class CountFuncAll(CountFunc):
         super().__init__(*args)
 
     def to_mongo(self):
-        return {'$sum': 1}
+        return {'count': self.alias}
 
 
 class CountFuncSingle(CountFunc, SingleParamFunc):
diff --git a/djaosdb/sql2mongo/query.py b/djaosdb/sql2mongo/query.py
index 3153599..ada223c 100644
--- a/djaosdb/sql2mongo/query.py
+++ b/djaosdb/sql2mongo/query.py
@@ -16,20 +16,23 @@ from dataclasses import dataclass, field as dataclass_field
 from caosdb import CaosDBException
 from sqlparse import parse as sqlparse
 from sqlparse import tokens
-from sqlparse.sql import (
-    Identifier, Parenthesis,
-    Where,
-    Statement)
+from sqlparse.sql import (Identifier, Parenthesis, Where, Statement)
+Values = Statement  # fallback alias, overridden below when available
+try:
+    from sqlparse.sql import Values as Values
+except ImportError:
+    pass  # older sqlparse (e.g. 0.2.4) has no sql.Values; keep the alias
 
 from ..exceptions import SQLDecodeError, MigrationError, print_warn
 from .functions import SQLFunc
 from .sql_tokens import (SQLToken, SQLStatement, SQLIdentifier,
                          AliasableToken, SQLConstIdentifier, SQLColumnDef, SQLColumnConstraint)
-from .converters import (
-    ColumnSelectConverter, AggColumnSelectConverter, FromConverter, WhereConverter,
-    AggWhereConverter, InnerJoinConverter, OuterJoinConverter, LimitConverter, AggLimitConverter, OrderConverter,
-    SetConverter, AggOrderConverter, DistinctConverter, NestedInQueryConverter, GroupbyConverter, OffsetConverter,
-    AggOffsetConverter, HavingConverter)
+from .converters import (ColumnSelectConverter, AggColumnSelectConverter,
+                         FromConverter, WhereConverter, InnerJoinConverter,
+                         OuterJoinConverter, LimitConverter, OrderConverter,
+                         SetConverter, DistinctConverter,
+                         NestedInQueryConverter, GroupbyConverter,
+                         OffsetConverter, AggOrderConverter, HavingConverter)
 
 from djaosdb import base
 logger = getLogger(__name__)
@@ -197,43 +200,59 @@ class SelectQuery(DQLQuery):
         return False
 
     def _make_pipeline(self):
-        pipeline = []
-        for join in self.joins:
-            pipeline.extend(join.to_mongo())
+        aggregation = {}
+        joins = None
+        where = None
+        if self.joins:
+            joins = [join.to_mongo() for join in self.joins]
 
         if self.nested_query:
-            pipeline.extend(self.nested_query.to_mongo())
+            aggregation["nested"] = self.nested_query.to_mongo()
 
         if self.where:
-            self.where.__class__ = AggWhereConverter
-            pipeline.append(self.where.to_mongo())
+            where = self.where.to_mongo()
 
         if self.groupby:
-            pipeline.extend(self.groupby.to_mongo())
+            aggregation["groupby"] = self.groupby.to_mongo()
 
         if self.having:
-            pipeline.append(self.having.to_mongo())
+            aggregation["having"] = self.having.to_mongo()
 
         if self.distinct:
-            pipeline.extend(self.distinct.to_mongo())
+            aggregation["distinct"] = self.distinct.to_mongo()
 
         if self.order:
             self.order.__class__ = AggOrderConverter
-            pipeline.append(self.order.to_mongo())
+            aggregation.update(self.order.to_mongo())
 
         if self.offset:
-            self.offset.__class__ = AggOffsetConverter
-            pipeline.append(self.offset.to_mongo())
+            aggregation.update(self.offset.to_mongo())
 
         if self.limit:
-            self.limit.__class__ = AggLimitConverter
-            pipeline.append(self.limit.to_mongo())
+            aggregation.update(self.limit.to_mongo())
 
         if self._needs_column_selection():
             self.selected_columns.__class__ = AggColumnSelectConverter
-            pipeline.extend(self.selected_columns.to_mongo())
-
-        return pipeline
+            select = self.selected_columns.to_mongo()
+            aggregation.update(select)
+
+        # merge a lone reference join with a WHERE on its property ("sub")
+        if joins is not None and where is not None:
+            if len(joins) == 1 and "reference" in joins[0]:
+                ref = joins[0]["reference"]
+                p = ref["p"]
+                if p in where["filter"]:
+                    ref["sub"] = where["filter"][p]
+                    aggregation.update(joins[0])
+                    joins = None
+                    where = None
+        if joins is not None:
+            aggregation["joins"] = joins
+        if where is not None:
+            aggregation.update(where)
+
+
+        return aggregation
 
     def _needs_column_selection(self):
         return not(self.distinct or self.groupby) and self.selected_columns
@@ -361,7 +380,7 @@ class InsertQuery(DMLQuery):
         tok = statement.next()
         self._cols = [token.column for token in SQLToken.tokens2sql(tok[1], self)]
 
-    def _fill_values(self, statement: SQLStatement):
+    def _fill_values(self, statement: U[SQLStatement, Values]):
         for tok in statement:
             if isinstance(tok, Parenthesis):
                 placeholder = SQLToken.token2sql(tok, self)
@@ -372,8 +391,17 @@ class InsertQuery(DMLQuery):
                     else:
                         values.append(index)
                 self._values.append(values)
-            elif not tok.match(tokens.Keyword, 'VALUES'):
-                raise SQLDecodeError
+            elif tok.match(tokens.Keyword, 'Values'):
+                # sqlparse==0.2.4
+                continue
+            elif tok.match(tokens.Whitespace, r"\s", True):
+                # sqlparse==0.3.1
+                continue
+            elif isinstance(tok, Values):
+                # sqlparse==0.3.1
+                self._fill_values(statement=tok)
+            else:
+                raise SQLDecodeError(tok)
 
     def execute(self):
         records = []
-- 
GitLab