# This file is a part of the CaosDB Project.
#
# Copyright (C) 2023 IndiScale GmbH <info@indiscale.com>
# Copyright (C) 2023 Daniel Hornung <d.hornung@indiscale.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Test the Python-side caching of server queries."""

import os
from datetime import datetime

import caosdb as db
from caosdb.cached import cached_query, cached_get_entity_by, cache_clear


def setup_function():
    """Reset server-side and client-side state before each test.

    Deletes every entity whose ID is greater than 99, clears the client-side
    query cache, and (re)writes the local upload fixture ``test.dat``.
    """
    # NOTE(review): presumably IDs <= 99 are reserved for server built-ins and
    # must survive between tests — confirm against the server configuration.
    leftovers = db.execute_query("FIND Entity WITH ID > 99")
    if len(leftovers) != 0:
        leftovers.delete()
    cache_clear()

    # Local file that test_caching uploads as a db.File.
    with open("test.dat", mode="w", encoding="utf-8") as fixture:
        fixture.write("hello world\n")


def teardown_function():
    """Clean up after each test.

    Calls ``setup_function`` to delete the test entities and clear the
    client-side cache (which, as a side effect, rewrites ``test.dat``), then
    removes the local ``test.dat`` fixture file.
    """
    setup_function()

    # Best-effort removal: a missing or undeletable file must not fail the
    # teardown, but only file-system errors are tolerated here. Anything else
    # indicates a real problem and is allowed to propagate.
    try:
        os.remove("test.dat")
    except OSError as err:
        print(err)


def test_caching():
    """Test if cached functions work at all.

    Inserts a RecordType with one Record and checks that every cached lookup
    flavour (by ID, by name, by query, and a raw cached query) resolves to that
    Record. Then uploads a File and checks the by-path lookup.
    """
    record_type = db.RecordType(name="RT1").insert()
    record = db.Record(name="rec1").add_parent(record_type).insert()

    # All lookups are evaluated up front, in this order, before any check.
    lookups = [
        cached_get_entity_by(eid=record.id),
        cached_get_entity_by(name=record.name),
        cached_get_entity_by(query="FIND RECORD rec1"),
        cached_query("FIND RECORD rec1")[0],
    ]
    for found in lookups:
        assert found.id == record.id

    uploaded = db.File(
        name="TestFile",
        description="Testfile Desc",
        path="testfiles/test.dat",
        file="test.dat",
    ).insert()
    by_path = cached_get_entity_by(path="testfiles/test.dat")
    assert uploaded.id == by_path.id


def test_caching_speed():
    """Test if caching is faster than uncached access.

    The same query is sent twice through ``cached_query``. The first call
    starts from an empty client-side cache (``setup_function`` clears it).
    The second identical call should be answered from the cache and therefore
    take less time.
    """
    # perf_counter is monotonic and high-resolution. datetime.now() is a wall
    # clock: it can jump on clock adjustments, and on coarse clocks both
    # durations may round to zero, making the assertion fail spuriously.
    from time import perf_counter

    rect = db.RecordType(name="RT1").insert()
    db.Record(name="rec1").add_parent(rect).insert()

    # Retrieve once to set up server-side caching, so the comparison below is
    # not dominated by server warm-up. This bypasses the client-side cache.
    db.execute_query("FIND RECORD rec1")

    # uncached
    start = perf_counter()
    cached_query("FIND RECORD rec1")
    time_uncached = perf_counter() - start

    # cached
    start = perf_counter()
    cached_query("FIND RECORD rec1")
    time_cached = perf_counter() - start

    error_str = ("Cached query was not faster than uncached, "
                 f"{time_uncached:.6f} s vs. {time_cached:.6f} s")
    assert time_cached < time_uncached, error_str