import asyncio
import gc
import shutil

import pytest

from joblib.memory import (AsyncMemorizedFunc, AsyncNotMemorizedFunc,
                           MemorizedResult, Memory, NotMemorizedResult)
from joblib.test.common import np, with_numpy
from joblib.testing import raises

from .test_memory import (corrupt_single_cache_item,
                          monkeypatch_cached_func_warn)


async def check_identity_lazy_async(func, accumulator, location):
    """ Similar to check_identity_lazy_async for coroutine functions"""
    memory = Memory(location=location, verbose=0)
    func = memory.cache(func)
    for i in range(3):
        for _ in range(2):
            value = await func(i)
            assert value == i
            assert len(accumulator) == i + 1


@pytest.mark.asyncio
async def test_memory_integration_async(tmpdir):
    accumulator = list()

    async def f(n):
        await asyncio.sleep(0.1)
        accumulator.append(1)
        return n

    await check_identity_lazy_async(f, accumulator, tmpdir.strpath)

    # Now test clearing
    for compress in (False, True):
        for mmap_mode in ('r', None):
            memory = Memory(location=tmpdir.strpath, verbose=10,
                            mmap_mode=mmap_mode, compress=compress)
            # First clear the cache directory, to check that our code can
            # handle that
            # NOTE: this line would raise an exception, as the database
            # file is still open; we ignore the error since we want to
            # test what happens if the directory disappears
            shutil.rmtree(tmpdir.strpath, ignore_errors=True)
            g = memory.cache(f)
            await g(1)
            g.clear(warn=False)
            current_accumulator = len(accumulator)
            out = await g(1)

        assert len(accumulator) == current_accumulator + 1
        # Also, check that Memory.eval works similarly
        evaled = await memory.eval(f, 1)
        assert evaled == out
        assert len(accumulator) == current_accumulator + 1

    # Now do a smoke test with a function defined in __main__, as the name
    # mangling rules are more complex
    f.__module__ = '__main__'
    memory = Memory(location=tmpdir.strpath, verbose=0)
    await memory.cache(f)(1)


@pytest.mark.asyncio
async def test_no_memory_async():
    accumulator = list()

    async def ff(x):
        await asyncio.sleep(0.1)
        accumulator.append(1)
        return x

    memory = Memory(location=None, verbose=0)
    gg = memory.cache(ff)
    for _ in range(4):
        current_accumulator = len(accumulator)
        await gg(1)
        assert len(accumulator) == current_accumulator + 1


@with_numpy
@pytest.mark.asyncio
async def test_memory_numpy_check_mmap_mode_async(tmpdir, monkeypatch):
    """Check that mmap_mode is respected even at the first call"""

    memory = Memory(location=tmpdir.strpath, mmap_mode='r', verbose=0)

    @memory.cache()
    async def twice(a):
        return a * 2

    a = np.ones(3)
    b = await twice(a)
    c = await twice(a)

    assert isinstance(c, np.memmap)
    assert c.mode == 'r'

    assert isinstance(b, np.memmap)
    assert b.mode == 'r'

    # Corrupts the file,  Deleting b and c mmaps
    # is necessary to be able edit the file
    del b
    del c
    gc.collect()
    corrupt_single_cache_item(memory)

    # Make sure that corrupting the file causes recomputation and that
    # a warning is issued.
    recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
    d = await twice(a)
    assert len(recorded_warnings) == 1
    exception_msg = 'Exception while loading results'
    assert exception_msg in recorded_warnings[0]
    # Asserts that the recomputation returns a mmap
    assert isinstance(d, np.memmap)
    assert d.mode == 'r'


@pytest.mark.asyncio
async def test_call_and_shelve_async(tmpdir):
    async def f(x, y=1):
        await asyncio.sleep(0.1)
        return x ** 2 + y

    # Test MemorizedFunc outputting a reference to cache.
    for func, Result in zip((AsyncMemorizedFunc(f, tmpdir.strpath),
                             AsyncNotMemorizedFunc(f),
                             Memory(location=tmpdir.strpath,
                                    verbose=0).cache(f),
                             Memory(location=None).cache(f),
                             ),
                            (MemorizedResult, NotMemorizedResult,
                             MemorizedResult, NotMemorizedResult,
                             )):
        for _ in range(2):
            result = await func.call_and_shelve(2)
            assert isinstance(result, Result)
            assert result.get() == 5

        result.clear()
        with raises(KeyError):
            result.get()
        result.clear()  # Do nothing if there is no cache.


@pytest.mark.asyncio
async def test_memorized_func_call_async(memory):

    async def ff(x, counter):
        await asyncio.sleep(0.1)
        counter[x] = counter.get(x, 0) + 1
        return counter[x]

    gg = memory.cache(ff, ignore=['counter'])

    counter = {}
    assert await gg(2, counter) == 1
    assert await gg(2, counter) == 1

    x, meta = await gg.call(2, counter)
    assert x == 2, "f has not been called properly"
    assert isinstance(meta, dict), (
        "Metadata are not returned by MemorizedFunc.call."
    )
