#include "sst_internal.h"
#include <linux/module.h>
#include <linux/device.h>
#include <linux/fs.h>
#include <linux/uaccess.h>
#include <linux/cdev.h>

/* Name used for the chrdev registration, the device class and the device node. */
#define SST_CHRDEV	"the-universe"

/* Major number returned by register_chrdev(); file-scope, so it starts at 0. */
static int sst_chrdev_major;
/* Device returned by device_create(); file-scope, so it starts as NULL. */
static struct device *my_universe;

/*
 * universe_open() - open handler for the character device.
 *
 * Stores whatever get_sst_info() returns in file->private_data so that the
 * read and write handlers can pick it up from there. Always returns 0.
 */
static int universe_open(struct inode *inode, struct file *file)
{
	void *ctx = get_sst_info();

	file->private_data = ctx;
	return 0;
}

/*
 * universe_release() - release handler for the character device.
 *
 * Nothing is torn down here; the handler only reports success.
 */
static int universe_release(struct inode *inode, struct file *file)
{
	return 0;
}

/*
 * universe_read() - copy the next answer to userspace.
 *
 * @file:  open file; private_data holds the sst_info set in universe_open()
 * @buf:   userspace destination buffer
 * @count: capacity of @buf in bytes
 * @ppos:  file position, advanced by the number of bytes copied
 *
 * Fetches one answer with sst_consume_answer() and copies at most @count
 * bytes of it (without the terminating NUL) to @buf. The answer is freed
 * on every path, so when @buf is too small the remainder is dropped and
 * only logged.
 *
 * Return: number of bytes actually copied, 0 if no answer could be
 * consumed, or -EFAULT if the copy to userspace failed.
 */
static ssize_t universe_read(struct file *file, char __user *buf, size_t count,
		       loff_t *ppos) {
	struct sst_info *sst_info = file->private_data;
	char *answer = NULL;
	size_t len;
	size_t to_copy;
	ssize_t ret;

	if (sst_consume_answer(sst_info, &answer)) {
		pr_debug("Cannot read from answers!\n");
		return 0;
	}
	len = strlen(answer);
	sst_debug("About to copy %zu bytes of your answer at 0x%lx to the userspace\n", len, (uintptr_t)answer);
	to_copy = min(len, count);
	if (to_copy != len) {
		pr_err("Sorry, your buffer is %zu bytes too small.\n", len - to_copy);
	}
	if (copy_to_user(buf, answer, to_copy)) {
		pr_err("User copy failed!\n");
		ret = -EFAULT;
		goto out_free;
	}
	sst_debug("Copied %zu bytes of your answer to the userspace: %s\n", to_copy, answer);
	/*
	 * Report what was really transferred: returning the full answer length
	 * after a truncated copy would make the caller trust bytes that were
	 * never written into its buffer.
	 */
	*ppos += to_copy;
	ret = to_copy;

out_free:
	kfree(answer);
	return ret;
}

/*
 * universe_write() - queue a question written by userspace.
 *
 * @file:  open file; private_data holds the sst_info set in universe_open()
 * @buf:   userspace source buffer
 * @count: number of bytes to take from @buf
 * @ppos:  file position (not used)
 *
 * Duplicates @count bytes from @buf into a NUL-terminated kernel buffer and
 * passes that buffer to sst_produce_question().
 *
 * Return: @count on success, the memdup_user_nul() error if the copy from
 * userspace failed, or -ENOMEM if the question could not be queued.
 */
static ssize_t universe_write(struct file *file, const char __user *buf, size_t count,
		       loff_t *ppos) {
	struct sst_info *sst_info = file->private_data;
	char *buf_copy;
	int err;

	buf_copy = memdup_user_nul(buf, count);
	if (IS_ERR(buf_copy)) {
		return PTR_ERR(buf_copy);
	}
	err = sst_produce_question(sst_info, buf_copy);
	if (err) {
		pr_err("Weird! The universe is full.\n");
		/*
		 * NOTE(review): assumes sst_produce_question() does not take
		 * ownership of the buffer when it fails -- confirm against its
		 * implementation. Without this the copy is leaked on every
		 * rejected write.
		 */
		kfree(buf_copy);
		return -ENOMEM;
	}
	sst_debug("Asked the universe a question...\n");
	return count;
}

/* File operations wired to the character device registered in sst_chrdev_init(). */
static const struct file_operations universe_fops = {
	.owner = THIS_MODULE,
	.open = universe_open,
	.release = universe_release,
	.read = universe_read,
	.write = universe_write,
};

/*
 * Device class registered in sst_chrdev_init() and used as the parent class
 * of the device created there; it carries the same name as the chrdev.
 */
static struct class universe_class = {
	.name		= SST_CHRDEV,
};

/*
 * sst_chrdev_init() - module entry point.
 *
 * Sets up the sst state, then registers the character device (dynamic
 * major), the device class and finally the device itself. Each failure
 * unwinds the steps that already succeeded, in reverse order.
 *
 * Return: 0 on success or the error code of the step that failed.
 */
static int __init sst_chrdev_init(void) {
	int err;

	/*
	 * Initialise sst before anything is published: once the chrdev is
	 * registered the file operations become reachable, and universe_open()
	 * calls get_sst_info().
	 * NOTE(review): assumes sst_init() does not itself rely on the chrdev
	 * or device existing -- confirm.
	 */
	err = sst_init();
	if (err) {
		pr_err("Cannot init sst_common: %d\n", err);
		goto out;
	}
	err = register_chrdev(0, SST_CHRDEV, &universe_fops);
	if (err < 0) {
		pr_err("Cannot register chrdev: %d\n", err);
		goto out_sst;
	}
	/* With major 0 requested, the positive return value is the assigned major. */
	sst_chrdev_major = err;
	err = class_register(&universe_class);
	if (err) {
		pr_err("Cannot register universe class: %d\n", err);
		goto out_chrdev;
	}
	my_universe = device_create(&universe_class, NULL, MKDEV(sst_chrdev_major, 0), NULL, SST_CHRDEV);
	if (IS_ERR(my_universe)) {
		err = PTR_ERR(my_universe);
		/* Do not leave an ERR_PTR value behind in the file-scope pointer. */
		my_universe = NULL;
		pr_err("Cannot create device: %d\n", err);
		goto out_class;
	}
	pr_notice("Loaded module %s\n", KBUILD_MODNAME);
	return 0;

out_class:
	class_unregister(&universe_class);
out_chrdev:
	unregister_chrdev(sst_chrdev_major, SST_CHRDEV);
out_sst:
	sst_destroy();
out:
	return err;
}

/*
 * sst_chrdev_exit() - module exit point.
 *
 * Destroys the sst state, then removes the device, the class and the
 * character device registration, and logs that the module was unloaded.
 */
static void __exit sst_chrdev_exit(void)
{
	const dev_t universe_devt = MKDEV(sst_chrdev_major, 0);

	sst_destroy();
	device_destroy(&universe_class, universe_devt);
	class_unregister(&universe_class);
	unregister_chrdev(sst_chrdev_major, SST_CHRDEV);
	pr_notice("Unloaded module %s\n", KBUILD_MODNAME);
}

/* Hook the init and exit routines into module load and unload. */
module_init(sst_chrdev_init);
module_exit(sst_chrdev_exit);
/*
 * NOTE(review): "LGPL" does not appear to be one of the license idents the
 * kernel treats as GPL-compatible, while this file calls class_register()
 * and device_create(), which are believed to be GPL-only exports -- confirm
 * the intended license string with the module's author.
 */
MODULE_LICENSE("LGPL");