from __future__ import annotations

import multiprocessing
import threading
import weakref
from typing import Any, MutableMapping

try:
    from dask.utils import SerializableLock
except ImportError:
    # no need to worry about serializing the lock
    SerializableLock = threading.Lock  # type: ignore


# Locks used by multiple backends.
# Neither HDF5 nor the netCDF-C library are thread-safe, so calls into either
# library are expected to be serialized through one of these shared,
# module-level locks.  ``SerializableLock`` falls back to ``threading.Lock``
# when dask is not installed (see the import at the top of the file).
HDF5_LOCK = SerializableLock()
NETCDFC_LOCK = SerializableLock()


# Registry of per-key ``threading.Lock`` objects, filled in by
# ``_get_threaded_lock``.  Values are held weakly: once no caller references a
# lock any more, its entry is dropped automatically.
_FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary()


# Serializes the look-up-or-create step in ``_get_threaded_lock`` so that two
# threads asking for the same key at the same time cannot each create (and
# hold) a different lock for it.
_FILE_LOCKS_GUARD = threading.Lock()


def _get_threaded_lock(key):
    """Return the ``threading.Lock`` shared by all callers using ``key``.

    Parameters
    ----------
    key : hashable
        Identifier of the resource to protect (typically a filename).

    Returns
    -------
    threading.Lock
        The same lock object for every caller that passes an equal ``key``
        while any caller still holds a reference to it.
    """
    # Without the guard, two threads could both miss in _FILE_LOCKS, each
    # create a fresh lock, and the second assignment would overwrite the
    # first -- leaving the threads "protected" by two different locks.
    with _FILE_LOCKS_GUARD:
        try:
            lock = _FILE_LOCKS[key]
        except KeyError:
            lock = _FILE_LOCKS[key] = threading.Lock()
    return lock


def _get_multiprocessing_lock(key):
    # TODO: make use of the key -- maybe use locket.py?
    # https://github.com/mwilliamson/locket.py
    del key  # unused
    return multiprocessing.Lock()


def _get_lock_maker(scheduler=None):
    """Returns an appropriate function for creating resource locks.

    Parameters
    ----------
    scheduler : str or None
        Dask scheduler being used.

    See Also
    --------
    dask.utils.get_scheduler_lock
    """

    if scheduler is None:
        return _get_threaded_lock
    elif scheduler == "threaded":
        return _get_threaded_lock
    elif scheduler == "multiprocessing":
        return _get_multiprocessing_lock
    elif scheduler == "distributed":
        # Lazy import distributed since it is can add a significant
        # amount of time to import
        try:
            from dask.distributed import Lock as DistributedLock
        except ImportError:
            DistributedLock = None
        return DistributedLock
    else:
        raise KeyError(scheduler)


def _get_scheduler(get=None, collection=None) -> str | None:
    """Determine the dask scheduler that is being used.

    None is returned if no dask scheduler is active.

    See Also
    --------
    dask.base.get_scheduler
    """
    try:
        # Fix for bug caused by dask installation that doesn't involve the toolz library
        # Issue: 4164
        import dask
        from dask.base import get_scheduler  # noqa: F401

        actual_get = get_scheduler(get, collection)
    except ImportError:
        return None

    try:
        from dask.distributed import Client

        if isinstance(actual_get.__self__, Client):
            return "distributed"
    except (ImportError, AttributeError):
        pass

    try:
        # As of dask=2.6, dask.multiprocessing requires cloudpickle to be installed
        # Dependency removed in https://github.com/dask/dask/pull/5511
        if actual_get is dask.multiprocessing.get:
            return "multiprocessing"
    except AttributeError:
        pass

    return "threaded"


def get_write_lock(key):
    """Return a lock suited to the active dask scheduler for writing ``key``.

    Parameters
    ----------
    key : str
        Identifier of the resource being written to, usually a filename.

    Returns
    -------
    A lock object usable like ``threading.Lock`` (``acquire``/``release`` and
    the context-manager protocol).
    """
    make_lock = _get_lock_maker(_get_scheduler())
    return make_lock(key)


def acquire(lock, blocking=True):
    """Acquire ``lock``, optionally without blocking.

    Papers over differences between lock implementations (threading,
    multiprocessing, dask and dask-distributed) in how ``acquire`` accepts its
    blocking flag.
    """
    if not blocking:
        # Pass the flag positionally rather than as ``blocking=``:
        # - multiprocessing.Lock names the parameter "block";
        # - dask.distributed.Lock takes blocking as its first argument;
        # - older dask SerializableLock (v1.0.0 or earlier) and Python 2's
        #   threading.Lock do not accept a "blocking" keyword.
        return lock.acquire(blocking)
    # A plain blocking acquire needs no arguments at all.
    return lock.acquire()


class CombinedLock:
    """A combination of multiple locks.

    Like a locked door, a CombinedLock is locked if any of its constituent
    locks are locked.
    """

    def __init__(self, locks):
        self.locks = tuple(set(locks))  # remove duplicates

    def acquire(self, blocking=True):
        """Acquire every constituent lock.

        Returns True when all locks were acquired. If any constituent reports
        a failed acquisition (e.g. a non-blocking attempt on a held lock), the
        locks already taken by this call are released again and False is
        returned, so a failed attempt never leaves the combination partially
        held.
        """
        held = []
        for lock in self.locks:
            if not acquire(lock, blocking=blocking):
                # Roll back in reverse order of acquisition.
                for taken in reversed(held):
                    taken.release()
                return False
            held.append(lock)
        return True

    def release(self):
        for lock in self.locks:
            lock.release()

    def __enter__(self):
        for lock in self.locks:
            lock.__enter__()

    def __exit__(self, *args):
        for lock in self.locks:
            lock.__exit__(*args)

    def locked(self):
        # ``locked`` must be *called*; the bound method itself is always truthy.
        return any(lock.locked() for lock in self.locks)

    def __repr__(self):
        return f"CombinedLock({list(self.locks)!r})"


class DummyLock:
    """DummyLock provides the lock API without any actual locking.

    ``acquire`` always reports success, so code that checks its return value
    (for example a non-blocking attempt) proceeds as it would after taking an
    uncontended real lock.
    """

    def acquire(self, blocking=True):
        # Match threading.Lock.acquire, which returns True on success; a no-op
        # lock can never be contended, so acquisition always "succeeds".
        return True

    def release(self):
        pass

    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass

    def locked(self):
        return False


def combine_locks(locks):
    """Merge an iterable of locks into one lock object.

    ``None`` entries are skipped and nested ``CombinedLock`` instances are
    flattened. With no locks left a ``DummyLock`` is returned, a single lock is
    returned as-is, and several locks are wrapped in a ``CombinedLock``.
    """
    flattened = []
    for candidate in locks:
        if candidate is None:
            continue
        if isinstance(candidate, CombinedLock):
            flattened.extend(candidate.locks)
        else:
            flattened.append(candidate)

    if not flattened:
        return DummyLock()
    if len(flattened) == 1:
        return flattened[0]
    return CombinedLock(flattened)


def ensure_lock(lock):
    """Return ``lock`` itself, or a no-op ``DummyLock`` when it is None or False."""
    locking_disabled = lock is None or lock is False
    if not locking_disabled:
        return lock
    return DummyLock()
