From d76d56cee35a17f37e9db61a52f06ff1b8ec9feb Mon Sep 17 00:00:00 2001 From: Timm Fitschen <t.fitschen@indiscale.com> Date: Mon, 5 Oct 2020 11:38:29 +0200 Subject: [PATCH] WIP: djaosdb --- djaosdb/base.py | 2 +- djaosdb/caosdb_client.py | 109 ++++++++++++++++++++++++++++---- djaosdb/sql2mongo/converters.py | 53 +++++++++------- djaosdb/sql2mongo/functions.py | 2 +- djaosdb/sql2mongo/query.py | 86 ++++++++++++++++--------- 5 files changed, 186 insertions(+), 66 deletions(-) diff --git a/djaosdb/base.py b/djaosdb/base.py index ea38a04..0284938 100644 --- a/djaosdb/base.py +++ b/djaosdb/base.py @@ -46,7 +46,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'GenericIPAddressField': caosdb.TEXT, 'NullBooleanField': caosdb.BOOLEAN, 'OneToOneField': caosdb.INTEGER, - 'PositiveIntegerField': 'long', + 'PositiveIntegerField': caosdb.INTEGER, 'PositiveSmallIntegerField': caosdb.INTEGER, 'SlugField': caosdb.TEXT, 'SmallIntegerField': caosdb.INTEGER, diff --git a/djaosdb/caosdb_client.py b/djaosdb/caosdb_client.py index 94d033f..d323ba7 100644 --- a/djaosdb/caosdb_client.py +++ b/djaosdb/caosdb_client.py @@ -1,6 +1,7 @@ -import caosdb +import abc from logging import getLogger -from .mockup_delegator import MockupDelegatorConnection +import caosdb +from django.contrib import messages logger = getLogger(__name__) @@ -23,18 +24,41 @@ class Result: # TODO return self +class CountResult(Result): + + def __init__(self, count, column_name): + super(CountResult, self).__init__() + self._results = [{column_name: count}] + + def __iter__(self): + return iter(self._results) + class FindResult(Result): - def __init__(self, rows, columns): - self.alive = True + def __init__(self, rows, columns, sort=None, limit=None): + super(FindResult, self).__init__() self._results = [] + self._sort = sort for row in rows: self._results.append(dict(zip(columns, row))) + # todo sort + + if limit is not None: + self._results = self._results[0:limit] + def __iter__(self): return iter(self._results) +class UpdateManyResult(Result): + + def __init__(self, container): + super(UpdateManyResult, self).__init__() + self.modified_count = len(container) # todo + self.matched_count = len(container) + + class InsertManyResult(Result): def __init__(self, container): super(InsertManyResult, self).__init__() @@ -77,9 +101,19 @@ class DefaultCaosDBClientDelegate: self._connection = caosdb.configure_connection(**kwargs) def _get_filter_clause(self, fil): + logger.debug("enter _get_filter_clause(%s)", fil) if "$and" in fil: components = [self._get_filter_clause(comps) for comps in fil["$and"]] return " AND".join(components) + if "reference" in fil: + ref = fil["reference"] + result = ref["str"] + if "sub" in ref: + if not "$in" in ref["sub"]: + raise NotImplementedError("_get_filter_clause(%s)", fil) + components = [" ID = " + str(val) for val in ref["sub"]["$in"]] + result += " WITH (" + " OR".join(components) + " )" + return result if "p" in fil: n = "" if fil["negation"] is True: @@ -90,15 +124,17 @@ class DefaultCaosDBClientDelegate: return f' {n}{p}{o}"{v}"' raise NotImplementedError("_get_filter_clause(%s)", fil) - def find(self, record_type, *args, **kwargs): + def _find(self, record_type, *args, **kwargs): filter_clause = "" if "filter" in kwargs: fil = kwargs["filter"] - filter_clause = "WITH " + self._get_filter_clause(fil) + filter_clause = " WITH " + self._get_filter_clause(fil) query = f'FIND RECORD "{record_type}"{filter_clause}' - res = caosdb.execute_query(query) + return caosdb.execute_query(query) + def find(self, record_type, *args, **kwargs): + res = self._find(record_type, *args, **kwargs) projection = kwargs["projection"] rows = res.get_property_values(*projection) return FindResult(rows, projection) @@ -108,8 +144,9 @@ class DefaultCaosDBClientDelegate: return [e.name for e in res if e.name is not None] def list_property_names(self): - res = caosdb.execute_query("SELECT name FROM PROPERTY") - return [e.name for e in res if e.name is not None] + res1 = caosdb.execute_query("SELECT name FROM PROPERTY") + res2 = caosdb.execute_query("SELECT name FROM RECORDTYPE") + return [e.name for e in res1 + res2 if e.name is not None] def create_record_type(self, name : str, properties : list): c = caosdb.Container() @@ -128,6 +165,7 @@ class DefaultCaosDBClientDelegate: def insert_many(self, record_type : str, records : list): c = caosdb.Container() + property_names = self.list_property_names() for rec in records: name, description = None, None if "name" in rec: @@ -139,11 +177,28 @@ class DefaultCaosDBClientDelegate: new_rec.add_parent(record_type) if "properties" in rec: for prop in rec["properties"]: - new_rec.add_property(**prop) + if prop["name"] in property_names: + new_rec.add_property(**prop) + else: + logger.warning("%s has not been stored due to changes to the schema.", prop["name"]) c.append(new_rec) c.insert(unique=False) return InsertManyResult(c) + def update_many(self, record_type: str, filter, update): + res = self._find(record_type, filter=filter) + update_set = update["$set"] + properties = update_set.keys() + for e in res: + for p in properties: + val = update_set[p] + if hasattr(val, "isoformat"): + val = val.isoformat() + e.get_property(p).value = val + res.update() + return UpdateManyResult(res) + + def add_foreign_key(self, record_type, *args, **kwargs): c = caosdb.Container() rt = caosdb.RecordType(record_type).retrieve() @@ -158,6 +213,36 @@ class DefaultCaosDBClientDelegate: rt.get_property(kwargs["property"][0][0]).datatype=referenced c.update() + def aggregate(self, record_type, *args, **kwargs): + kwargs.update(args[0]) + return self._aggregate(record_type, **kwargs) + + def _aggregate(self, record_type, reference=None, sort=None, + projection=None, limit=None, filter=None, count=False): + prefix = 'FIND' + filter_clause = "" + if reference is not None: + fil = {"reference": reference} + filter_clause = " WHICH " + self._get_filter_clause(fil) + if filter is not None: + if reference is not None: + filter_clause += " AND" + filter_clause += " WITH " + self._get_filter_clause(filter) + if count: + prefix = "COUNT" + elif sort is not None: + prefix = 'SELECT ' + ", ".join(list(projection) + + list(sort.keys())) + " FROM" + query = f'{prefix} RECORD "{record_type}"{filter_clause}' + + logger.debug("execute_query(%s)", query) + res = caosdb.execute_query(query) + if count: + return CountResult(res, count) + rows = res.get_property_values(*projection) + return FindResult(rows, projection, sort, limit) + + class CaosDBClient: def __init__(self, enforce_schema=True, @@ -265,8 +350,8 @@ class CaosDBClient: Returns ------- """ - # TODO - raise Exception("NOT IMPLEMENTED") + logger.debug("update_many(%s, %s)", args, kwargs) + return self._delegate.update_many(*args, **kwargs) def drop_record_type(self, *args, **kwargs): """drop_record_type diff --git a/djaosdb/sql2mongo/converters.py b/djaosdb/sql2mongo/converters.py index d87f0f1..f5c711c 100644 --- a/djaosdb/sql2mongo/converters.py +++ b/djaosdb/sql2mongo/converters.py @@ -61,20 +61,21 @@ class ColumnSelectConverter(Converter): class AggColumnSelectConverter(ColumnSelectConverter): def to_mongo(self): - project = {} + if isinstance(self.sql_tokens[0], CountFuncAll): + return self.sql_tokens[0].to_mongo() if any(isinstance(tok, SQLFunc) for tok in self.sql_tokens): # A SELECT func without groupby clause still needs a groupby # in MongoDB return self._using_group_by() elif isinstance(self.sql_tokens[0], SQLConstIdentifier): + project = {} project[self.sql_tokens[0].alias] = self.sql_tokens[0].to_mongo() else: - for selected in self.sql_tokens: - project[selected.field] = True + project = [selected.field for selected in self.sql_tokens] - return [{'$project': project}] + return {'projection': project} def _using_group_by(self): group = { @@ -126,12 +127,6 @@ class WhereConverter(Converter): return {'filter': self.op.to_mongo()} -class AggWhereConverter(WhereConverter): - - def to_mongo(self): - return {'$match': self.op.to_mongo()} - - class JoinConverter(Converter): @abc.abstractmethod @@ -165,6 +160,25 @@ class JoinConverter(Converter): self.left_column = sql.right_column self.right_column = sql.left_column + def _reference(self): + if self.left_column == "id": + return { + "back_reference": { + "p": self.right_column, + "v": self.right_table, + "negation": False, + "str": + f"IS REFERENCED BY {self.right_table} AS {self.right_column}" + }} + elif self.right_column == "id": + return { "reference": { + "p": self.left_column, + "v": self.right_table, + "negation": False, + "str": f"REFERENCES {self.right_table} AS {self.left_column}" + }} + return None + def _lookup(self): if self.left_table == self.query.left_table: local_field = self.left_column @@ -194,6 +208,11 @@ class InnerJoinConverter(JoinConverter): else: match_field = f'{self.left_table}.{self.left_column}' + # try to construct a references/referenced-by filter + reference = self._reference() + if reference is not None: + return reference + lookup = self._lookup() pipeline = [ { @@ -263,12 +282,6 @@ class LimitConverter(Converter): return {'limit': self.limit} -class AggLimitConverter(LimitConverter): - - def to_mongo(self): - return {'$limit': self.limit} - - class OrderConverter(Converter): def __init__(self, *args): self.columns: List[SQLIdentifier] = [] @@ -315,7 +328,7 @@ class AggOrderConverter(OrderConverter): for tok in self.columns: sort[tok.field] = tok.order - return {'$sort': sort} + return {'sort': sort} class _Tokens2Id: @@ -484,9 +497,3 @@ class OffsetConverter(Converter): def to_mongo(self): return {'skip': self.offset} - - -class AggOffsetConverter(OffsetConverter): - - def to_mongo(self): - return {'$skip': self.offset} diff --git a/djaosdb/sql2mongo/functions.py b/djaosdb/sql2mongo/functions.py index 76c551e..fe1c02c 100644 --- a/djaosdb/sql2mongo/functions.py +++ b/djaosdb/sql2mongo/functions.py @@ -92,7 +92,7 @@ class CountFuncAll(CountFunc): super().__init__(*args) def to_mongo(self): - return {'$sum': 1} + return {'count': self.alias} class CountFuncSingle(CountFunc, SingleParamFunc): diff --git a/djaosdb/sql2mongo/query.py b/djaosdb/sql2mongo/query.py index 3153599..ada223c 100644 --- a/djaosdb/sql2mongo/query.py +++ b/djaosdb/sql2mongo/query.py @@ -16,20 +16,23 @@ from dataclasses import dataclass, field as dataclass_field from caosdb import CaosDBException from sqlparse import parse as sqlparse from sqlparse import tokens -from sqlparse.sql import ( - Identifier, Parenthesis, - Where, - Statement) +from sqlparse.sql import (Identifier, Parenthesis, Where, Statement) +Values = Statement +try: + from sqlparse.sql import Values as Values +except ImportError: + pass # compatibility sqlparse 0.2.4 - 0.3.1 from ..exceptions import SQLDecodeError, MigrationError, print_warn from .functions import SQLFunc from .sql_tokens import (SQLToken, SQLStatement, SQLIdentifier, AliasableToken, SQLConstIdentifier, SQLColumnDef, SQLColumnConstraint) -from .converters import ( - ColumnSelectConverter, AggColumnSelectConverter, FromConverter, WhereConverter, - AggWhereConverter, InnerJoinConverter, OuterJoinConverter, LimitConverter, AggLimitConverter, OrderConverter, - SetConverter, AggOrderConverter, DistinctConverter, NestedInQueryConverter, GroupbyConverter, OffsetConverter, - AggOffsetConverter, HavingConverter) +from .converters import (ColumnSelectConverter, AggColumnSelectConverter, + FromConverter, WhereConverter, InnerJoinConverter, + OuterJoinConverter, LimitConverter, OrderConverter, + SetConverter, DistinctConverter, + NestedInQueryConverter, GroupbyConverter, + OffsetConverter, AggOrderConverter, HavingConverter) from djaosdb import base logger = getLogger(__name__) @@ -197,43 +200,59 @@ class SelectQuery(DQLQuery): return False def _make_pipeline(self): - pipeline = [] - for join in self.joins: - pipeline.extend(join.to_mongo()) + aggregation = {} + joins = None + where = None + if self.joins: + joins = [join.to_mongo() for join in self.joins] if self.nested_query: - pipeline.extend(self.nested_query.to_mongo()) + aggregation["nested"] = self.nested_query.to_mongo() if self.where: - self.where.__class__ = AggWhereConverter - pipeline.append(self.where.to_mongo()) + where = self.where.to_mongo() if self.groupby: - pipeline.extend(self.groupby.to_mongo()) + aggregation["groupby"] = self.groupby.to_mongo() if self.having: - pipeline.append(self.having.to_mongo()) + aggregation["having"] = self.having.to_mongo() if self.distinct: - pipeline.extend(self.distinct.to_mongo()) + aggregation["distinct"] = self.distinct.to_mongo() if self.order: self.order.__class__ = AggOrderConverter - pipeline.append(self.order.to_mongo()) + aggregation.update(self.order.to_mongo()) if self.offset: - self.offset.__class__ = AggOffsetConverter - pipeline.append(self.offset.to_mongo()) + aggregation.update(self.offset.to_mongo()) if self.limit: - self.limit.__class__ = AggLimitConverter - pipeline.append(self.limit.to_mongo()) + aggregation.update(self.limit.to_mongo()) if self._needs_column_selection(): self.selected_columns.__class__ = AggColumnSelectConverter - pipeline.extend(self.selected_columns.to_mongo()) - - return pipeline + select = self.selected_columns.to_mongo() + aggregation.update(select) + + ## merge join and where clause + if joins is not None and where is not None: + if len(joins) == 1 and "reference" in joins[0]: + ref = joins[0]["reference"] + p = ref["p"] + if p in where["filter"]: + ref["sub"] = where["filter"][p] + aggregation.update(joins[0]) + joins = None + where = None + if joins is not None: + aggregation["joins"] = joins + if where is not None: + aggregation.update(where) + + + return aggregation def _needs_column_selection(self): return not(self.distinct or self.groupby) and self.selected_columns @@ -361,7 +380,7 @@ class InsertQuery(DMLQuery): tok = statement.next() self._cols = [token.column for token in SQLToken.tokens2sql(tok[1], self)] - def _fill_values(self, statement: SQLStatement): + def _fill_values(self, statement: U[SQLStatement, Values]): for tok in statement: if isinstance(tok, Parenthesis): placeholder = SQLToken.token2sql(tok, self) @@ -372,8 +391,17 @@ class InsertQuery(DMLQuery): else: values.append(index) self._values.append(values) - elif not tok.match(tokens.Keyword, 'VALUES'): - raise SQLDecodeError + elif tok.match(tokens.Keyword, 'Values'): + # sqlparse==0.2.4 + continue + elif tok.match(tokens.Whitespace, r"\s", True): + # sqlparse==0.3.1 + continue + elif isinstance(tok, Values): + # sqlparse==0.3.1 + self._fill_values(statement=tok) + else: + raise SQLDecodeError(tok) def execute(self): records = [] -- GitLab