/*
 * Copyright (c) 2017,  Broadcom Ltd. All rights reserved.
 * Copyright (c) 2016,  Mellanox Technologies. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/mm.h>
#include <linux/dma-mapping.h>
#include <linux/sched.h>
#include <linux/export.h>
#include <linux/hugetlb.h>
#include <linux/slab.h>
#ifdef HAVE_LINUX_SCHED_MM_H
#include <linux/sched/mm.h>
#endif
#ifdef HAVE_LINUX_SCHED__SIGNAL_H
#include <linux/sched/signal.h>
#endif

#include <rdma/ib_umem_odp.h>
#include <rdma/ib_umem.h>
#include <rdma/ib_verbs.h>

#include "peer_compat.h"
#include "peer_mem.h"
#include "peer_umem.h"

/*
 * Hash key for a umem: bits 16..31 of the (integer) umem address.
 * The argument is parenthesised so that compound expressions such as
 * get_key(a | b) are masked as a whole.
 */
#define get_key(umem)  (((umem) & 0xFFFFFFFF) >> 16)

/*
 * Find the peer-memory table entry whose ->umem equals @umem.
 *
 * The device is taken from the umem itself, its peer-mem device is
 * resolved with get_peer_mem_device(), and the matching hash bucket is
 * walked under peer_dev->hash_lock.
 *
 * Returns the entry, or NULL when the device has no peer-mem device or
 * no entry references @umem. The lock is not held on return.
 */
static inline struct ib_peer_mem_tbl_entry *get_peer_mem_tbl_entry
				(struct ib_umem *umem)
{
	struct ib_peer_mem_tbl_entry *cursor = NULL;
	struct ib_peer_mem_tbl_entry *match = NULL;
	struct ib_peer_mem_device *peer_dev;
	struct ib_device *device;
	u32 key;

#ifdef HAVE_IB_DEVICE_IN_UMEM
	device = umem->ibdev;
#else
	device = umem->context->device;
#endif

	peer_dev = get_peer_mem_device(device);
	if (!peer_dev)
		return NULL;

	key = get_key((unsigned long)umem);

	mutex_lock(&peer_dev->hash_lock);
	hash_for_each_possible(peer_dev->peer_mem_hash, cursor, entry, key) {
		if (cursor->umem == umem) {
			match = cursor;
			break;
		}
	}
	mutex_unlock(&peer_dev->hash_lock);

	return match;
}

struct ib_peer_umem *ib_peer_mem_get_data(struct ib_umem *umem)
{
	struct ib_peer_mem_tbl_entry *tbl_entry = NULL;
#ifdef HAVE_IB_DEVICE_IN_UMEM
	struct ib_device *device = umem->ibdev;
#else
	struct ib_device *device = umem->context->device;
#endif
	struct ib_peer_mem_device *peer_dev = NULL;
	bool found = false;
	u32 key;

	peer_dev = get_peer_mem_device(device);
	if (peer_dev) {
		key = get_key((unsigned long)umem);
		mutex_lock(&peer_dev->hash_lock);
		hash_for_each_possible
			(peer_dev->peer_mem_hash, tbl_entry, entry, key) {
			if (tbl_entry->umem == umem) {
				found = true;
				mutex_unlock(&peer_dev->hash_lock);
				break;
			}
		}
		if (found)
			return &tbl_entry->peer_umem;
		mutex_unlock(&peer_dev->hash_lock);
	}
	return NULL;
}
EXPORT_SYMBOL(ib_peer_mem_get_data);

/*
 * peer_umem_get - build an ib_umem backed by peer (non-host) memory
 * @ibdev:       device stored in the new umem (only used when
 *               HAVE_IB_DEVICE_IN_UMEM is defined)
 * @ib_peer_mem: peer client that claimed the range; recorded in
 *               tbl_entry->peer_umem
 * @tbl_entry:   caller-allocated table entry; its peer_umem and page_shift
 *               fields are filled in here. It is not freed by this function.
 * @p_umem:      requested region (context, address, length, writable)
 * @addr:        start address handed to the peer client's get_pages()
 * @flags:       IB_UMEM_PEER_INVAL_SUPP requests an invalidation context;
 *               the IB_UMEM_DMA_SYNC bit is forwarded to dma_map()
 *
 * Sequence: optional invalidation context -> get_pages() -> allocate the
 * umem -> dma_map(). Each failure unwinds the steps already taken.
 *
 * Returns the new umem, or ERR_PTR(ret) on failure. Every failure path
 * ends in ib_put_peer_client() for @ib_peer_mem.
 */
static struct ib_umem *peer_umem_get(struct ib_device *ibdev,
				     struct ib_peer_memory_client *ib_peer_mem,
				     struct ib_peer_mem_tbl_entry *tbl_entry,
				     struct peer_mem_umem *p_umem,
				     unsigned long addr,
				     unsigned long flags)
{
	int ret;
	const struct peer_memory_client *peer_mem = ib_peer_mem->peer_mem;
	struct invalidation_ctx *ictx = NULL;
	struct ib_umem *umem;
	struct sg_table sg_head;
	struct ib_peer_umem *peer_umem = &tbl_entry->peer_umem;

	dev_dbg(NULL, "%s: peer_umem = %p ib_peer_mem = %p\n",
		__func__, peer_umem, ib_peer_mem);
	peer_umem->ib_peer_mem = ib_peer_mem;
	if (flags & IB_UMEM_PEER_INVAL_SUPP) {
		ret = ib_peer_create_invalidation_ctx(ib_peer_mem, p_umem,
						      &ictx);
		if (ret)
			goto end;
		/* Link the context and the peer umem to each other. */
		peer_umem->invalidation_ctx = ictx;
		ictx->peer_umem = peer_umem;
		dev_dbg(NULL, "%s: peer_umem->invalidation_ctx = %p\n",
			__func__, peer_umem->invalidation_ctx);
	}

	/*
	 * We always request write permissions to the pages, to force breaking
	 * of any CoW during the registration of the MR. For read-only MRs we
	 * use the "force" flag to indicate that CoW breaking is required but
	 * the registration should not fail if referencing read-only areas.
	 */
	ret = peer_mem->get_pages(addr, p_umem->length,
				  1, !p_umem->writable,
				  &sg_head,
				  peer_umem->peer_mem_client_context,
				  ictx ? ictx->context_ticket : 0);
	if (ret)
		goto out;

	/* Found a valid peer umem. So allocate a umem here */
	umem = kzalloc(sizeof(*umem), GFP_KERNEL);
	if (!umem) {
		ret = -ENOMEM;
		goto put_pages;
	}

#ifdef HAVE_IB_DEVICE_IN_UMEM
	umem->ibdev = ibdev;
#else
	umem->context   = p_umem->context;
#endif
	umem->length    = p_umem->length;
	umem->address   = p_umem->address;
	umem->writable  = p_umem->writable;
	/*
	 * Struct copy of the table filled by get_pages(). The local sg_head
	 * is kept because the error path below hands it to put_pages().
	 */
#ifdef HAVE_IB_UMEM_SG_APPEND_TABLE
	umem->sgt_append.sgt = sg_head;
#else
	umem->sg_head	= sg_head;
#endif

	/* Remember the peer's page size (as a shift) in the table entry. */
	tbl_entry->page_shift = ilog2(peer_mem->get_page_size
				      (peer_umem->peer_mem_client_context));

#ifdef HAVE_IB_UMEM_PAGE_SIZE
	umem->page_size = peer_mem->get_page_size
			(peer_umem->peer_mem_client_context);
#endif
#ifdef HAVE_IB_UMEM_PAGE_SHIFT
	umem->page_shift = ilog2(peer_mem->get_page_size
			(peer_umem->peer_mem_client_context));
#endif

	/*
	 * DMA-map through the peer client. The sg table, the DMA device and
	 * the location receiving the mapped-entry count are selected by the
	 * compat #ifdefs below.
	 */
#ifdef HAVE_IB_UMEM_SG_APPEND_TABLE
	ret = peer_mem->dma_map(&umem->sgt_append.sgt,
#else
	ret = peer_mem->dma_map(&umem->sg_head,
#endif
				peer_umem->peer_mem_client_context,
#ifdef HAVE_IB_DEVICE_IN_UMEM
				umem->ibdev->dma_device,
#else
				umem->context->device->dma_device,
#endif
				flags & IB_UMEM_DMA_SYNC,
#ifdef HAVE_IB_UMEM_SG_APPEND_TABLE
				&umem->sgt_append.sgt.nents);
#else
				&umem->nmap);
#endif
	if (ret)
		goto free_umem;

	/* The umem is only attached to the invalidation context on success. */
	if (ictx)
		ictx->umem = umem;
	return umem;

	/* Error unwind: each label undoes one step and falls through. */
free_umem:
	kfree(umem);

put_pages:
	peer_mem->put_pages(&sg_head,
			    peer_umem->peer_mem_client_context
);
out:
	if (ictx)
		ib_peer_destroy_invalidation_ctx(ib_peer_mem, ictx);
end:
	ib_put_peer_client(ib_peer_mem, peer_umem->peer_mem_client_context
);
	return ERR_PTR(ret);
}

static void peer_umem_release(struct ib_umem *umem)
{
	struct ib_peer_memory_client *ib_peer_mem = NULL;
	const struct peer_memory_client *peer_mem = NULL;
	struct invalidation_ctx *ictx = NULL;
	struct ib_peer_umem *peer_umem = NULL;
	struct sg_table *sg_head;

	peer_umem = ib_peer_mem_get_data(umem);
	if (!peer_umem) {
		WARN_ON(1);
		return;
	}

	ib_peer_mem = peer_umem->ib_peer_mem;
	peer_mem = ib_peer_mem->peer_mem;

	ictx = peer_umem->invalidation_ctx;

	if (ictx)
		ib_peer_destroy_invalidation_ctx(ib_peer_mem, ictx);
#ifdef HAVE_IB_UMEM_SG_APPEND_TABLE
	sg_head = &umem->sgt_append.sgt;
#else
	sg_head = &umem->sg_head;
#endif

	peer_mem->dma_unmap(sg_head,
			    peer_umem->peer_mem_client_context,
#ifdef HAVE_IB_DEVICE_IN_UMEM
				umem->ibdev->dma_device
#else
				umem->context->device->dma_device
#endif
			);
	peer_mem->put_pages(sg_head,
			    peer_umem->peer_mem_client_context
);
	ib_put_peer_client(ib_peer_mem,
			   peer_umem->peer_mem_client_context
);
	kfree(umem);
}

/**
 * ib_umem_activate_invalidation_notifier - install the invalidation callback
 * @umem:   peer-memory umem to arm
 * @func:   callback stored in the umem's invalidation context
 * @cookie: opaque value stored alongside @func
 *
 * Stores @func and @cookie in the invalidation context of @umem while
 * holding the peer client's lock.
 *
 * Return:
 *  0        on success, or when the peer data has no peer client attached
 *  -ENOMEM  when @umem is not tracked as peer memory (a warning is raised)
 *  -EINVAL  when @umem has no invalidation context, or when the peer has
 *           already invalidated the pages
 */
int ib_umem_activate_invalidation_notifier(struct ib_umem *umem,
					   void (*func)(void *cookie,
					   struct ib_umem *umem,
					   unsigned long addr, size_t size),
					   void *cookie)
{
	struct invalidation_ctx *ictx = NULL;
	int ret = 0;
	struct ib_peer_umem *peer_umem = NULL;

	peer_umem = ib_peer_mem_get_data(umem);
	if (!peer_umem) {
		WARN_ON(1);
		return -ENOMEM;
	}
	if (!peer_umem->ib_peer_mem)
		return 0;

	/*
	 * The invalidation context only exists when the umem was created
	 * with IB_UMEM_PEER_INVAL_SUPP; without it there is nothing to arm,
	 * and dereferencing it below would be a NULL pointer access.
	 */
	ictx = peer_umem->invalidation_ctx;
	if (!ictx)
		return -EINVAL;

	dev_dbg(NULL, "%s: peer_umem = %p\n", __func__, peer_umem);
	dev_dbg(NULL, "%s: peer_umem->ib_peer_mem = %p\n", __func__,
		peer_umem->ib_peer_mem);
	mutex_lock(&peer_umem->ib_peer_mem->lock);
	if (ictx->peer_invalidated) {
		pr_err("ib_umem_activate_invalidation_notifier: pages were invalidated by peer\n");
		ret = -EINVAL;
		goto end;
	}
	ictx->func = func;
	ictx->cookie = cookie;
	/* from that point any pending invalidations can be called */
end:
	mutex_unlock(&peer_umem->ib_peer_mem->lock);
	return ret;
}
EXPORT_SYMBOL(ib_umem_activate_invalidation_notifier);

/*
 * Get the mapping of the peer memory
 *
 * @ibdev:   device the umem is created for
 * @context: user context handed to ib_get_peer_client() and stored in the
 *           region description
 * @addr:    start address of the region
 * @size:    length of the region
 * @access:  access flags; only used to derive the writable bit
 * @flags:   IB_UMEM_xxx flags; peer memory is only tried when
 *           IB_UMEM_PEER_ALLOW is set
 *
 * Returns - umem, if a valid peer memory is identified
 *	   - Error, if any error during peer memory mapping
 *	   - NULL, if peer memory mapping is not found (flag not set, no
 *	     peer-mem device for the IB device, or no peer client claims
 *	     the range)
 */
static struct ib_umem *ib_peer_mem_umem_get(struct ib_device *ibdev,
					    struct ib_ucontext *context,
					    unsigned long addr,
					    size_t size, int access,
					    unsigned long flags)
{
	struct ib_peer_memory_client *peer_mem_client;
	struct ib_peer_mem_tbl_entry *tbl_entry;
	struct ib_peer_mem_device *peer_dev;
	struct peer_mem_umem p_umem = {};
	struct ib_umem *peer_umem;
	u32 key;

	if (!(flags & IB_UMEM_PEER_ALLOW))
		return NULL;

	/*
	 * Resolve the peer-mem device first, from the same device the lookup
	 * helpers derive from the umem. Without it the new umem could not be
	 * added to the hash table, so the release path would never recognise
	 * it as peer memory and the table entry would be leaked.
	 */
#ifdef HAVE_IB_DEVICE_IN_UMEM
	peer_dev = get_peer_mem_device(ibdev);
#else
	peer_dev = get_peer_mem_device(context->device);
#endif
	if (!peer_dev)
		return NULL;

	tbl_entry = kzalloc(sizeof(*tbl_entry), GFP_KERNEL);
	if (!tbl_entry)
		return ERR_PTR(-ENOMEM);

	p_umem.context   = context;
	p_umem.length    = size;
	p_umem.address   = addr;
	p_umem.page_shift = PAGE_SHIFT;
	p_umem.writable  = ib_access_writable_compat(access);

	peer_mem_client =
		ib_get_peer_client
			(context, addr, size, flags,
			 &tbl_entry->peer_umem.peer_mem_client_context
			);
	if (IS_ERR(peer_mem_client)) {
		kfree(tbl_entry);
		return ERR_CAST(peer_mem_client);
	}
	if (!peer_mem_client) {
		/*
		 * The peer_mem check didn't provide a valid memory
		 * Return NULL so that the caller can tell that no peer
		 * client owns this range.
		 */
		kfree(tbl_entry);
		return NULL;
	}

	/* On failure peer_umem_get() has already put the peer client. */
	peer_umem = peer_umem_get(ibdev, peer_mem_client,
				  tbl_entry, &p_umem,
				  addr, flags);
	if (IS_ERR(peer_umem)) {
		kfree(tbl_entry);
		return ERR_CAST(peer_umem);
	}

	/* Fully initialise the entry before it becomes visible to lookups. */
	tbl_entry->umem = peer_umem;
	tbl_entry->peer_dev = peer_dev;
	key = get_key((unsigned long)peer_umem);

	mutex_lock(&peer_dev->hash_lock);
	hash_add(peer_dev->peer_mem_hash,
		 &tbl_entry->entry, key);
	mutex_unlock(&peer_dev->hash_lock);

	dev_dbg(NULL, "%s: peer_dev = %p tbl_entry = %p umem = %p\n",
		__func__, peer_dev, tbl_entry, peer_umem);

	return peer_umem;
}

/**
 * ib_umem_get_flags - map userspace memory, falling back to peer memory
 *
 * The region is first mapped as host memory through
 * __ib_umem_get_compat(). If that fails with -EFAULT (or yields NULL) the
 * region is offered to the peer-memory path instead.
 *
 * @ibdev: device the umem is created for
 * @context: userspace context to pin memory for
 * @udata: user data passed through to the host mapping call
 * @addr: userspace virtual address to start at
 * @size: length of region to pin
 * @access: IB_ACCESS_xxx flags for memory being pinned
 * @flags: IB_UMEM_xxx flags for memory being used
 *
 * Return: the umem on success, or an ERR_PTR():
 *  -EINVAL for a zero @size or an @addr + @size that overflows,
 *  -EPERM  when the caller may not lock memory,
 *  otherwise the error reported by the host or peer mapping. When the
 *  host mapping failed with -EFAULT and no peer memory covers the range,
 *  that -EFAULT is returned.
 */
struct ib_umem *ib_umem_get_flags(struct ib_device *ibdev,
				  struct ib_ucontext *context,
				  struct ib_udata *udata,
				  unsigned long addr,
				  size_t size, int access,
				  unsigned long flags)
{
	struct ib_umem *umem;
	struct ib_umem *peer_umem;

	if (!size)
		return ERR_PTR(-EINVAL);

	/*
	 * If the combination of the addr and size requested for this memory
	 * region causes an integer overflow, return error.
	 */
	if (((addr + size) < addr) ||
	    PAGE_ALIGN(addr + size) < (addr + size))
		return ERR_PTR(-EINVAL);

	if (!can_do_mlock())
		return ERR_PTR(-EPERM);

	/* Map host memory */
	umem = __ib_umem_get_compat(ibdev, context,
				    udata, addr, size,
				    access, flags);

	/*
	 * A valid umem, or any error other than -EFAULT, is final. Only
	 * -EFAULT (or a NULL result) is worth retrying as peer memory.
	 */
	if (umem && !(IS_ERR(umem) && PTR_ERR(umem) == -EFAULT))
		return umem;

	peer_umem = ib_peer_mem_umem_get(ibdev, context, addr, size, access,
					 flags);

	/*
	 * No peer client claimed the range: report the original host
	 * mapping failure rather than handing the caller a NULL umem.
	 */
	if (!peer_umem && IS_ERR(umem))
		return umem;

	return peer_umem;
}
EXPORT_SYMBOL(ib_umem_get_flags);

/**
 * ib_umem_release_flags - release memory pinned with ib_umem_get_flags
 * @umem: umem struct to release; NULL is accepted and ignored
 *
 * A umem that is tracked in the peer-memory hash table is released
 * through the peer path and its table entry is removed and freed. Any
 * other umem is handed to ib_umem_release().
 */
void ib_umem_release_flags(struct ib_umem *umem)
{
	struct ib_peer_mem_tbl_entry *tbl_entry;
	struct ib_peer_mem_device *peer_dev;

	/* The table lookup dereferences the umem, so reject NULL up front. */
	if (!umem)
		return;

	tbl_entry = get_peer_mem_tbl_entry(umem);
	if (!tbl_entry) {
		ib_umem_release(umem);
		return;
	}

	peer_dev = tbl_entry->peer_dev;
	BUG_ON(!peer_dev);

	/*
	 * peer_umem_release() finds its data through the hash table, so the
	 * entry must stay hashed until it has returned.
	 */
	peer_umem_release(umem);
	dev_dbg(NULL, "%s:umem = %p tbl_entry = %p\n", __func__, umem, tbl_entry);

	mutex_lock(&peer_dev->hash_lock);
	hash_del(&tbl_entry->entry);
	mutex_unlock(&peer_dev->hash_lock);
	kfree(tbl_entry);
}
EXPORT_SYMBOL(ib_umem_release_flags);

/**
 * ib_umem_get_peer_page_shift - get the page_shift recorded for a peer umem
 * @umem: umem to look up in the peer-memory table
 *
 * Return: the page_shift stored in the umem's table entry, or -EINVAL
 * when @umem has no table entry.
 */
int ib_umem_get_peer_page_shift(struct ib_umem *umem)
{
	struct ib_peer_mem_tbl_entry *found = get_peer_mem_tbl_entry(umem);

	if (!found)
		return -EINVAL;

	return found->page_shift;
}
EXPORT_SYMBOL(ib_umem_get_peer_page_shift);

/*
 * peer_mem_init_hash_tbl - initialise the peer-memory hash table
 * @peer_dev: peer-mem device whose peer_mem_hash table is initialised
 *
 * hash_init() is applied to the table member directly.
 */
void peer_mem_init_hash_tbl(struct ib_peer_mem_device *peer_dev)
{
	hash_init(peer_dev->peer_mem_hash);
}
