diff --git a/djaosdb/caosdb_client.py b/djaosdb/caosdb_client.py index 308941017ee13c7fc7fc8e59c69f1fc59f052571..0c5ce9d7d93adb9fbfd952932499063b19f42dba 100644 --- a/djaosdb/caosdb_client.py +++ b/djaosdb/caosdb_client.py @@ -51,8 +51,9 @@ class FindResult(Result): if columns is not None: if sort: + fkey = list(sort)[0] named_columns = [dict(zip(columns, row)) for row in rows] - named_columns.sort(key=lambda x: _sort_key(x, sort[0][0]), reverse=sort[0][1] < 0) + named_columns.sort(key=lambda x: _sort_key(x, fkey), reverse=sort[fkey] < 0) self._results = named_columns[skip:upper] else: self._results = [] @@ -371,7 +372,7 @@ class DefaultCaosDBClientDelegate: filter_clause += " WITH" + self._get_filter_clause(filter) if count: prefix = "COUNT" - elif sort is not None: + elif sort is not None and projection is not None: prefix = 'SELECT ' + ", ".join(list(projection) + list(sort.keys())) + " FROM" query = f'{prefix} RECORD "{record_type}"{filter_clause}' diff --git a/djaosdb/sql2mongo/converters.py b/djaosdb/sql2mongo/converters.py index 2e937bcab7f2c059d2eda113607a01163179d0ce..7e3387c7c3fb0ac53d27091250871be255102105 100644 --- a/djaosdb/sql2mongo/converters.py +++ b/djaosdb/sql2mongo/converters.py @@ -313,7 +313,7 @@ class OrderConverter(Converter): def to_mongo(self): sort = [(tok.column, tok.order) for tok in self.columns] - return {'sort': sort} + return {'sort': OrderedDict(sort)} class SetConverter(Converter): diff --git a/djaosdb/sql2mongo/functions.py b/djaosdb/sql2mongo/functions.py index 806b361480e47712c7c30db998f5c67e7adf5eca..c8ff8951b2854b5fea2279617088840addbcc9e4 100644 --- a/djaosdb/sql2mongo/functions.py +++ b/djaosdb/sql2mongo/functions.py @@ -22,6 +22,8 @@ class SQLFunc(AliasableToken): func = token[0].get_name() if func == 'COUNT': return CountFunc.token2sql(token, query) + if func == "JSON_CONTAINS": + return JSONContainsFunc(token, query) else: return SimpleFunc(token, query) @@ -37,6 +39,36 @@ class SQLFunc(AliasableToken): def to_mongo(self) -> dict: raise NotImplementedError +class JSONContainsFunc(SQLFunc): + def __init__(self, *args): + super().__init__(*args) + if self.alias: + params = list(self._token[0].get_parameters()) + else: + params = list(self._token.get_parameters()) + self.iden = SQLToken.token2sql(params[0], self.query) + self.val = params[1] + + @property + def left_table(self): + return self.table + + @property + def table(self): + return self.iden.table + + @property + def column(self): + return self.iden.column + + @property + def field(self): + if self.alias: + return self.alias + if self.iden in self.query.token_alias.token2alias: + return self.query.token_alias.token2alias[self.iden] + return self.iden + class SingleParamFunc(SQLFunc): def __init__(self, *args): diff --git a/djaosdb/sql2mongo/operators.py b/djaosdb/sql2mongo/operators.py index 323e772cb8b92fe7d7d0e0946626a8f6c531f0ce..020f5baa4e27da7c4530c6164273a3a117d80cb7 100644 --- a/djaosdb/sql2mongo/operators.py +++ b/djaosdb/sql2mongo/operators.py @@ -466,8 +466,12 @@ class _StatementParser: elif isinstance(tok, Identifier): op = IsTrueOp(tok, self.query) + elif isinstance(tok, Function) and tok.get_name() == "JSON_CONTAINS": + op = CmpOp(tok, self.query) + elif tok.match(tokens.Whitespace, r"\s", True): pass + else: raise SQLDecodeError @@ -564,9 +568,20 @@ class CmpOp(_Op): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._identifier = SQLToken.token2sql(self.statement.left, self.query) self._field_ext = None self._operator = None + if isinstance(self.statement, Function): + params = list(self.statement.get_parameters()) + self._identifier = params[0] + self._operator = OPERATOR_MAP["="] + right = params[1] + index = re_index(right.value) + self._constant = self.params[index] if index is not None else None + self._filter_type = "pov" + self._identifier = SQLToken.token2sql(self.statement, self.query) + return + + self._identifier = SQLToken.token2sql(self.statement.left, self.query) if isinstance(self.statement.right, Identifier): right = SQLToken.token2sql(self.statement.right, self.query) if right.column.lower() == "id": @@ -596,6 +611,8 @@ class CmpOp(_Op): field = self._identifier.field if self._field_ext: ref = self._field_ext + elif self._identifier.query.left_table == self._identifier.table: + field = self._identifier.column elif self._identifier.query.left_table != self._identifier.table: ref = self._identifier.table field = self._identifier.column diff --git a/tests/test_new.py b/tests/test_new.py index 360fe074fe62739488794b00daf0dd95d12381f6..6ac234f36b93ebfb4d94417f35a089aa8f3ae443 100644 --- a/tests/test_new.py +++ b/tests/test_new.py @@ -20,6 +20,13 @@ class _MockContainer(_Container): def insert(self, *args, **kwargs): pass + def execute(self): + return self + + @property + def etag(self): + return "asdf" + class _MockConnection(CaosDBClient): @@ -29,6 +36,9 @@ class _MockConnection(CaosDBClient): if cached_record_types is not None: self.cached_record_types.update(cached_record_types) + def Query(self, *args, **kwargs): + return _MockContainer() + def configure_connection(self, **kwargs): return self @@ -428,7 +438,7 @@ def test_inner_and_outer_join(): def test_query_generation_conjuction(): connection = _MockConnection() - query = connection._delegate._generate_query( + query, key = connection._delegate._generate_query( record_type = 'auth_permission', sort = OrderedDict([('django_content_type.app_label', 1), ('django_content_type.model', 1), @@ -443,8 +453,8 @@ def test_query_generation_conjuction(): count=False) assert query == ('SELECT id, django_content_type.app_label, ' 'django_content_type.model, codename FROM RECORD ' - '"auth_permission" WITH ( auth_group_permissions.' - 'group_id="226" AND ( REFERENCES django_content_type ) )') + '"auth_permission" WITH ( "auth_group_permissions.' + 'group_id"="226" AND ( REFERENCES django_content_type ) )') def test_parse_select_join_with_reverse_on_clause(): sql = """ @@ -942,3 +952,35 @@ def test_not_in(): }], }], } + +def test_JSON_CONTAINS(): + sql = """ + SELECT COUNT(*) AS "__count" + FROM "A" WHERE ( + "A"."col1" = %(0)s + AND JSON_CONTAINS("A"."col2", %(1)s) + )""" + params = ('term1', 'term2') + cached_record_types = [ + "A", + ] + connection = _MockConnection(cached_record_types=cached_record_types) + q = Query(connection=connection, sql=sql, params=params) + caosdb_params = q._query._to_caosdb()[2] + assert caosdb_params["count"] == "__count" + print(caosdb_params["filter"]) + assert caosdb_params["filter"] == { + 'type': 'and', + 'elements': [{ + 'type': 'pov', + 'negation': False, + 'p': 'col1', + 'v': 'term1', + 'o': '=' + }, { + 'type': 'pov', + 'negation': False, + 'p': "col2", + 'v': 'term2', + 'o': '='} + ]}