/*
 * Copyright (c) 2017,  Broadcom Ltd. All rights reserved.
 * Copyright (c) 2016,  Mellanox Technologies. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>
#include "peer_mem.h"
#include "peer_umem.h"
#include "peer_compat.h"
#include <rdma/ib_verbs.h>
#include <rdma/ib_umem.h>

MODULE_LICENSE("Dual BSD/GPL");
MODULE_DESCRIPTION("IB peer memory module");
MODULE_VERSION(IB_PEER_MEM_MODULE_VERSION);

#ifndef HAVE_MM_KOBJ_EXPORT
struct kobject *mm_kobj;
#endif

struct list_head peer_dev_list = LIST_HEAD_INIT(peer_dev_list);
static DEFINE_MUTEX(peer_dev_lock);

static DEFINE_MUTEX(peer_memory_mutex);
static LIST_HEAD(peer_memory_list);

/*
 * get_peer_mem_device - find the peer-memory bookkeeping for an IB device
 * @device: IB device to look up; NULL yields NULL
 *
 * Walks peer_dev_list under peer_dev_lock and returns the first entry whose
 * ->device matches, or NULL when there is none.
 *
 * NOTE(review): the lock is dropped before returning, so nothing here keeps
 * the returned entry alive against a concurrent ib_peer_mem_remove_device().
 */
struct ib_peer_mem_device *get_peer_mem_device(struct ib_device *device)
{
	struct ib_peer_mem_device *entry;
	struct ib_peer_mem_device *match = NULL;

	if (device == NULL)
		return NULL;

	mutex_lock(&peer_dev_lock);
	list_for_each_entry(entry, &peer_dev_list, list) {
		if (entry->device != device)
			continue;
		match = entry;
		break;
	}
	mutex_unlock(&peer_dev_lock);

	return match;
}

/*
 * ib_peer_search_context - look up a live core ticket by its key
 * @ib_peer_client: client whose ticket list is searched
 * @key: value previously handed out through ib_peer_insert_context()
 *
 * Must be called with @ib_peer_client->lock held. Returns the matching ticket,
 * or NULL when no ticket with that key is on the list any more (for example
 * because ib_peer_destroy_invalidation_ctx() already unlinked and freed it).
 */
static struct core_ticket *
ib_peer_search_context(struct ib_peer_memory_client *ib_peer_client, u64 key)
{
	struct core_ticket *core_ticket;

	list_for_each_entry(core_ticket, &ib_peer_client->core_ticket_list,
			    ticket_list) {
		if (core_ticket->key == key)
			return core_ticket;
	}

	return NULL;
}

/*
 * ib_invalidate_peer_memory - invalidation entry point handed to peer clients
 * @reg_handle: the ib_peer_memory_client returned at registration
 * @core_context: ticket identifying the invalidation context
 *
 * The ticket is validated against the client's ticket list instead of being
 * dereferenced blindly, so a ticket whose context was already destroyed is
 * ignored rather than touching freed memory.
 *
 * If the context has no invalidation function yet, it is only marked as
 * peer-invalidated. Otherwise the function is called under the client lock.
 * If it leaves an invalidation in flight, this waits (lock dropped) until
 * ib_peer_destroy_invalidation_ctx() completes ->comp. Always returns 0.
 */
static int ib_invalidate_peer_memory(void *reg_handle, u64 core_context)
{
	struct ib_peer_memory_client *ib_peer_client = reg_handle;
	struct invalidation_ctx *invalidation_ctx;
	struct core_ticket *core_ticket;

	mutex_lock(&ib_peer_client->lock);
	core_ticket = ib_peer_search_context(ib_peer_client, core_context);
	if (!core_ticket) {
		/* Unknown or already destroyed context: nothing to do. */
		mutex_unlock(&ib_peer_client->lock);
		return 0;
	}

	invalidation_ctx = core_ticket->context;
	/* If context is not ready yet, mark it to be invalidated */
	if (!invalidation_ctx->func) {
		invalidation_ctx->peer_invalidated = 1;
		mutex_unlock(&ib_peer_client->lock);
		return 0;
	}

	invalidation_ctx->func(invalidation_ctx->cookie,
			       invalidation_ctx->umem, 0, 0);

	if (invalidation_ctx->inflight_invalidation) {
		/*
		 * Init the completion before letting another thread run; it is
		 * completed by ib_peer_destroy_invalidation_ctx(), which may free
		 * the context, so it must not be touched after the wait.
		 */
		init_completion(&invalidation_ctx->comp);
		mutex_unlock(&ib_peer_client->lock);
		wait_for_completion(&invalidation_ctx->comp);
		return 0;
	}

	mutex_unlock(&ib_peer_client->lock);
	return 0;
}

/*
 * ib_peer_insert_context - attach @context to a new core ticket
 * @ib_peer_client: client whose ticket list receives the ticket
 * @context: opaque pointer stored in the ticket
 * @context_ticket: receives the ticket key (the ticket's own address)
 *
 * The ticket is linked at the tail of the client's core_ticket_list under
 * the client lock. Returns 0, or -ENOMEM if the ticket cannot be allocated.
 */
static int ib_peer_insert_context(struct ib_peer_memory_client *ib_peer_client,
				  void *context,
				  u64 *context_ticket)
{
	struct core_ticket *ticket = kzalloc(sizeof(*ticket), GFP_KERNEL);

	if (ticket == NULL)
		return -ENOMEM;

	mutex_lock(&ib_peer_client->lock);
	ticket->context = context;
	ticket->key = (unsigned long)ticket;
	list_add_tail(&ticket->ticket_list, &ib_peer_client->core_ticket_list);
	*context_ticket = ticket->key;
	mutex_unlock(&ib_peer_client->lock);

	return 0;
}

/*
 * ib_peer_create_invalidation_ctx - creates invalidation context for given umem
 * @ib_peer_mem: peer client to be used
 * @umem: umem struct belongs to that context
 * @invalidation_ctx: output context
 */
int ib_peer_create_invalidation_ctx(struct ib_peer_memory_client *ib_peer_mem,
				    struct peer_mem_umem *p_umem,
				    struct invalidation_ctx **invalidation_ctx)
{
	int ret;
	struct invalidation_ctx *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	ret = ib_peer_insert_context(ib_peer_mem, ctx,
				     &ctx->context_ticket);
	if (ret) {
		kfree(ctx);
		return ret;
	}

	*invalidation_ctx = ctx;
	return 0;
}

/**
 * ib_peer_destroy_invalidation_ctx - destroy a given invalidation context
 * @ib_peer_mem: peer client the context was created for
 * @invalidation_ctx: context to destroy; it is freed before returning
 *
 * Unlinks and frees the context's core ticket under the client lock. If an
 * invalidation is in flight, it completes ->comp to wake the waiter. It then
 * frees the context and clears the owning umem's back-pointer (when one is
 * set).
 */
void ib_peer_destroy_invalidation_ctx(struct ib_peer_memory_client *ib_peer_mem,
				      struct invalidation_ctx *invalidation_ctx)
{
	struct core_ticket *core_ticket;
	struct ib_peer_umem *peer_umem;

	mutex_lock(&ib_peer_mem->lock);
	/* The ticket value is the ticket's address (see ib_peer_insert_context). */
	core_ticket = (struct core_ticket *)(unsigned long)
		      invalidation_ctx->context_ticket;
	if (core_ticket) {
		list_del(&core_ticket->ticket_list);
		kfree(core_ticket);
	}
	/*
	 * Check the inflight flag only after taking the lock and unlinking the
	 * ticket, so the decision is consistent with ib_invalidate_peer_memory(),
	 * which sets up ->comp under the same lock.
	 */
	if (invalidation_ctx->inflight_invalidation)
		complete(&invalidation_ctx->comp);
	mutex_unlock(&ib_peer_mem->lock);

	/*
	 * NOTE(review): the context (and ->comp inside it) is freed right after
	 * complete(). A waiter woken above re-takes ->comp's internal lock
	 * inside wait_for_completion() after waking, so this kfree() can race
	 * with it. The context should only be freed once the waiter has
	 * returned.
	 */
	peer_umem = invalidation_ctx->peer_umem;
	pr_debug("%s: Freeing invalidation context  = %p\n",
		 __func__, invalidation_ctx);
	kfree(invalidation_ctx);

	/*
	 * ib_peer_create_invalidation_ctx() allocates the context zeroed and
	 * does not set ->peer_umem itself, so guard against it being unset.
	 */
	if (peer_umem)
		peer_umem->invalidation_ctx = NULL;
}

/*
 * complete_peer - kref release callback for a peer client
 * @kref: the client's embedded reference counter
 *
 * Runs when the last reference is dropped and only signals unload_comp.
 * ib_unregister_peer_memory_client() waits on that completion and then
 * frees the client.
 */
static void complete_peer(struct kref *kref)
{
	complete(&container_of(kref, struct ib_peer_memory_client, ref)->unload_comp);
}



static struct kobject *peers_kobj;

static struct ib_peer_memory_client *get_peer_by_kobj(void *kobj)
{
        struct ib_peer_memory_client *ib_peer_client;

        mutex_lock(&peer_memory_mutex);
        list_for_each_entry(ib_peer_client, &peer_memory_list, core_peer_list) {
                if (ib_peer_client->kobj == kobj) {
                        kref_get(&ib_peer_client->ref);
                        goto found;
                }
        }

        ib_peer_client = NULL;
found:
        mutex_unlock(&peer_memory_mutex);
        return ib_peer_client;
}


static ssize_t version_show(struct kobject *kobj,
                            struct kobj_attribute *attr, char *buf)
{
        struct ib_peer_memory_client *ib_peer_client = get_peer_by_kobj(kobj);

        if (ib_peer_client) {
                sprintf(buf, "%s\n", ib_peer_client->peer_mem->version);
                kref_put(&ib_peer_client->ref, complete_peer);
                return strlen(buf);
        }
        /* not found - nothing is return */
        return 0;
}

/* Read-only "version" file, backed by version_show(). */
static struct kobj_attribute version_attr = __ATTR_RO(version);

/* NULL-terminated attribute list installed in every client's directory. */
static struct attribute *peer_mem_attrs[] = {
	&version_attr.attr,
	NULL,
};


static void destroy_peer_sysfs(struct ib_peer_memory_client *ib_peer_client)
{
        kobject_put(ib_peer_client->kobj);
        if (list_empty(&peer_memory_list))
                kobject_put(peers_kobj);
}

/*
 * create_peer_sysfs - create the sysfs directory for a new client
 * @ib_peer_client: client being registered; ->peer_mem->name names the dir
 *
 * The first registered client also creates the shared "memory_peers"
 * directory under mm_kobj. Call with peer_memory_mutex held and before the
 * client is added to peer_memory_list.
 *
 * Returns 0 on success. On failure it returns -ENOMEM (parent directory),
 * -EINVAL (client directory) or the sysfs_create_group() error. Everything
 * created here is released again and the kobject pointers are cleared.
 */
static int create_peer_sysfs(struct ib_peer_memory_client *ib_peer_client)
{
	int ret;

	if (list_empty(&peer_memory_list)) {
		/*
		 * First client: create "memory_peers" under mm_kobj, which is
		 * /sys/kernel/mm when the kernel exports mm_kobj.
		 */
		peers_kobj = kobject_create_and_add("memory_peers", mm_kobj);
		if (!peers_kobj)
			return -ENOMEM;
	}

	ib_peer_client->peer_mem_attr_group.attrs = peer_mem_attrs;
	/*
	 * No group name: the attributes go straight into the per-client
	 * directory, which is created explicitly so its kobject can be kept
	 * for later lookups.
	 */
	ib_peer_client->peer_mem_attr_group.name = NULL;
	ib_peer_client->kobj = kobject_create_and_add(ib_peer_client->peer_mem->name,
						      peers_kobj);
	if (!ib_peer_client->kobj) {
		ret = -EINVAL;
		goto put_peers;
	}

	/* Create the files associated with this kobject */
	ret = sysfs_create_group(ib_peer_client->kobj,
				 &ib_peer_client->peer_mem_attr_group);
	if (ret)
		goto put_client;

	return 0;

put_client:
	kobject_put(ib_peer_client->kobj);
	ib_peer_client->kobj = NULL;
put_peers:
	if (list_empty(&peer_memory_list)) {
		kobject_put(peers_kobj);
		peers_kobj = NULL;
	}

	return ret;
}


/*
 * ib_register_peer_memory_client - register a peer memory provider
 * @peer_client: provider description and operations (name, version,
 *		 acquire, release, ...); must stay valid until unregistration
 * @invalidate_callback: optional. When non-NULL it receives the function the
 *			 provider calls to invalidate memory, and the client is
 *			 then only offered requests flagged
 *			 IB_UMEM_PEER_INVAL_SUPP by ib_get_peer_client()
 *
 * Creates the client's sysfs directory and adds it to peer_memory_list.
 * Returns an opaque handle for ib_unregister_peer_memory_client(), or NULL
 * on failure (no provider / no ->acquire() hook, allocation, or sysfs).
 */
void *ib_register_peer_memory_client(struct peer_memory_client *peer_client,
				     int (**invalidate_callback)
				     (void *reg_handle, u64 core_context))
{
	struct ib_peer_memory_client *ib_peer_client;

	/*
	 * ib_get_peer_client() calls ->acquire() unconditionally; reject a
	 * provider that would make it jump through a NULL pointer later.
	 */
	if (!peer_client || !peer_client->acquire)
		return NULL;

	ib_peer_client = kzalloc(sizeof(*ib_peer_client), GFP_KERNEL);
	if (!ib_peer_client)
		return NULL;

	INIT_LIST_HEAD(&ib_peer_client->core_ticket_list);
	mutex_init(&ib_peer_client->lock);
	init_completion(&ib_peer_client->unload_comp);
	/* Registration reference; dropped by ib_unregister_peer_memory_client(). */
	kref_init(&ib_peer_client->ref);
	ib_peer_client->peer_mem = peer_client;
	ib_peer_client->last_ticket = 1;
	/* Once peer supplied a non NULL callback it's an indication that
	 * invalidation support is required for any memory owning.
	 */
	if (invalidate_callback) {
		*invalidate_callback = ib_invalidate_peer_memory;
		ib_peer_client->invalidation_required = 1;
	}

	mutex_lock(&peer_memory_mutex);
	if (create_peer_sysfs(ib_peer_client)) {
		kfree(ib_peer_client);
		ib_peer_client = NULL;
		goto end;
	}
	list_add_tail(&ib_peer_client->core_peer_list, &peer_memory_list);
end:
	mutex_unlock(&peer_memory_mutex);

	return ib_peer_client;
}

#ifdef USE_NVIDIA_GPU
EXPORT_SYMBOL(ib_register_peer_memory_client);
#else
EXPORT_SYMBOL_GPL(ib_register_peer_memory_client);
#endif

/*
 * ib_unregister_peer_memory_client - undo ib_register_peer_memory_client()
 * @reg_handle: handle returned by registration; NULL is ignored
 *
 * Unlinks the client and removes its sysfs directory under
 * peer_memory_mutex. It then drops the registration reference and blocks
 * until every other reference (ib_get_peer_client(), sysfs readers) has
 * been put, and finally frees the client.
 *
 * NOTE(review): destroy_peer_sysfs() runs with peer_memory_mutex held. sysfs
 * removal waits for running show() callbacks, and version_show() takes
 * peer_memory_mutex, so a concurrent read of "version" may deadlock here -
 * confirm.
 */
void ib_unregister_peer_memory_client(void *reg_handle)
{
	struct ib_peer_memory_client *ib_peer_client = reg_handle;

	/* Registration returns NULL on failure; tolerate it being passed back. */
	if (!ib_peer_client)
		return;

	mutex_lock(&peer_memory_mutex);
	list_del(&ib_peer_client->core_peer_list);
	destroy_peer_sysfs(ib_peer_client);
	mutex_unlock(&peer_memory_mutex);

	kref_put(&ib_peer_client->ref, complete_peer);
	wait_for_completion(&ib_peer_client->unload_comp);
	kfree(ib_peer_client);
}

#ifdef USE_NVIDIA_GPU
EXPORT_SYMBOL(ib_unregister_peer_memory_client);
#else
EXPORT_SYMBOL_GPL(ib_unregister_peer_memory_client);
#endif

/*
 * ib_get_peer_client - find the peer client that owns a user address range
 * @context: unused here
 * @addr: start of the user range
 * @size: length of the user range
 * @flags: umem flags; IB_UMEM_PEER_INVAL_SUPP allows invalidating clients
 * @peer_client_context: filled by the client's ->acquire()
 *
 * Clients that require invalidation are skipped unless @flags allows it.
 * The first remaining client's ->acquire() decides the outcome:
 * - a positive return means the client owns the range, and it is returned
 *   with a reference held (drop it with ib_put_peer_client());
 * - 0 is reported as ERR_PTR(-ENOMEM);
 * - a negative errno is returned as ERR_PTR().
 * Returns NULL when no client is eligible.
 */
struct ib_peer_memory_client *ib_get_peer_client(struct ib_ucontext *context,
						 unsigned long addr,
						 size_t size,
						 unsigned long flags,
						 void **peer_client_context)
{
	struct ib_peer_memory_client *ib_peer_client;
	int ret;

	mutex_lock(&peer_memory_mutex);
	list_for_each_entry(ib_peer_client, &peer_memory_list, core_peer_list) {
		/* In case peer requires invalidation it can't own memory
		 * which doesn't support it
		 */
		if (ib_peer_client->invalidation_required &&
		    !(flags & IB_UMEM_PEER_INVAL_SUPP))
			continue;

		ret = ib_peer_client->peer_mem->acquire(addr, size, NULL, NULL,
							peer_client_context);
		if (ret > 0)
			goto found;

		/* acquire returns 0 for failure cases, treat as no memory */
		if (!ret)
			ret = -ENOMEM;

		mutex_unlock(&peer_memory_mutex);
		return ERR_PTR(ret);
	}
	mutex_unlock(&peer_memory_mutex);

	return NULL;

found:
	/*
	 * Take the reference before dropping peer_memory_mutex. Once the mutex
	 * is released, ib_unregister_peer_memory_client() can unlink the client,
	 * put the last reference and free it.
	 */
	kref_get(&ib_peer_client->ref);
	mutex_unlock(&peer_memory_mutex);

	return ib_peer_client;
}
EXPORT_SYMBOL(ib_get_peer_client);

void ib_put_peer_client(struct ib_peer_memory_client *ib_peer_client,
			void *peer_client_context)
{
	if (ib_peer_client->peer_mem->release)
		ib_peer_client->peer_mem->release(peer_client_context);

	kref_put(&ib_peer_client->ref, complete_peer);
}
EXPORT_SYMBOL(ib_put_peer_client);

struct ib_peer_mem_device *ib_peer_mem_add_device(struct ib_device *device)
{
	struct ib_peer_mem_device *mem_dev = NULL;

	/* Check if GPU direct capable */
	// return otherwise

	mem_dev = kmalloc(sizeof (*mem_dev), GFP_KERNEL);

	if (!mem_dev)
		return NULL;

	mem_dev->device = device;

	pr_info("adding ib device = 0x%llx\n", (u64)mem_dev);

	peer_mem_init_hash_tbl(mem_dev);
	mutex_init(&mem_dev->hash_lock);
	list_add_tail(&mem_dev->list, &peer_dev_list);
	return mem_dev;
}
EXPORT_SYMBOL(ib_peer_mem_add_device);

void ib_peer_mem_remove_device(struct ib_peer_mem_device *peer_mem_dev)
{
	if (!peer_mem_dev)
		return;

	list_del(&peer_mem_dev->list);
	pr_info("Removing ib device = 0x%llx\n", (u64)peer_mem_dev);
	kfree(peer_mem_dev);
}
EXPORT_SYMBOL(ib_peer_mem_remove_device);

/*
 * Module load: all lists and locks are statically initialised, so there is
 * nothing to set up beyond a debug trace.
 */
static int __init ib_peer_mem_mod_init(void)
{
	pr_debug("ib_peer_mem: Module Init\n");

	return 0;
}

/* Module unload: nothing is torn down here, only a debug trace is emitted. */
static void __exit ib_peer_mem_mod_exit(void)
{
	pr_debug("ib_peer_mem: Module exit\n");
}

module_init(ib_peer_mem_mod_init);
module_exit(ib_peer_mem_mod_exit);

