Skip to content
Snippets Groups Projects
Verified Commit 4341f5f7 authored by Timm Fitschen's avatar Timm Fitschen
Browse files

Fix sorting with names

parent c113bca8
No related branches found
No related tags found
No related merge requests found
...@@ -36,12 +36,16 @@ class CountResult(Result): ...@@ -36,12 +36,16 @@ class CountResult(Result):
def __iter__(self): def __iter__(self):
return iter(self._results) return iter(self._results)
def _sort_key(x, field): def _sort_key(x, field, names):
return str(x[field] or "") res = x[field] or ""
if res in names:
res = names[res]
return str(res)
class FindResult(Result): class FindResult(Result):
def __init__(self, rows, columns, sort=None, limit=None, skip=None): def __init__(self, rows, columns, sort=None, limit=None, skip=None,
record_types=None, connection=None):
super(FindResult, self).__init__() super(FindResult, self).__init__()
self._sort = sort self._sort = sort
...@@ -52,9 +56,13 @@ class FindResult(Result): ...@@ -52,9 +56,13 @@ class FindResult(Result):
if columns is not None: if columns is not None:
if sort: if sort:
fkey = list(sort)[0] fkey = list(sort)[0]
names = {}
if record_types is not None and fkey in record_types:
names = connection.list_record_names(fkey)
reverse = sort[fkey] < 0 reverse = sort[fkey] < 0
named_columns = [dict(zip(columns, row)) for row in rows] named_columns = [dict(zip(columns, row)) for row in rows]
named_columns.sort(key=lambda x: _sort_key(x, fkey), reverse=reverse) named_columns.sort(key=lambda x: _sort_key(x, fkey, names), reverse=reverse)
self._results = named_columns[skip:upper] self._results = named_columns[skip:upper]
else: else:
self._results = [] self._results = []
...@@ -232,7 +240,8 @@ class DefaultCaosDBClientDelegate: ...@@ -232,7 +240,8 @@ class DefaultCaosDBClientDelegate:
res = self._find(record_type, *args, **kwargs) res = self._find(record_type, *args, **kwargs)
projection = kwargs["projection"] projection = kwargs["projection"]
rows = self._get_property_values(res, projection) rows = self._get_property_values(res, projection)
return FindResult(rows, projection, sort, limit, skip) return FindResult(rows, projection, sort, limit, skip,
record_types=self.cached_record_types, connection=self)
def find_auth(self, record_type, *args, **kwargs): def find_auth(self, record_type, *args, **kwargs):
if record_type == "caosdb_auth_user": if record_type == "caosdb_auth_user":
...@@ -246,6 +255,10 @@ class DefaultCaosDBClientDelegate: ...@@ -246,6 +255,10 @@ class DefaultCaosDBClientDelegate:
return FindResult(rows, columns) return FindResult(rows, columns)
return FindResult([], []) return FindResult([], [])
def list_record_names(self, record_type):
res = self._query(f"SELECT name FROM RECORD '{record_type}'")
return {e.id: e.name for e in res if e.name is not None}
def list_record_type_names(self): def list_record_type_names(self):
res = self._query("SELECT name FROM RECORDTYPE") res = self._query("SELECT name FROM RECORDTYPE")
return [e.name for e in res if e.name is not None] + ["caosdb_auth_user"] return [e.name for e in res if e.name is not None] + ["caosdb_auth_user"]
...@@ -364,7 +377,9 @@ class DefaultCaosDBClientDelegate: ...@@ -364,7 +377,9 @@ class DefaultCaosDBClientDelegate:
if count: if count:
return CountResult(res, count) return CountResult(res, count)
rows = self._get_property_values(res, projection) rows = self._get_property_values(res, projection)
return FindResult(rows, projection, sort, limit, skip) return FindResult(rows, projection, sort, limit, skip,
record_types=self.cached_record_types,
connection=self)
def _generate_query(self, record_type, sort, projection, filter, count): def _generate_query(self, record_type, sort, projection, filter, count):
prefix = 'FIND' prefix = 'FIND'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment