1
0
linux/debian/patches/misc-ntsync7/0005-ntsync-Introduce-NTSYNC_IOC_CREATE_MUTEX.patch

223 lines
5.7 KiB
Diff
Raw Normal View History

2024-12-16 07:12:49 +03:00
From bcdeaefdc4b60e7845232c201427717df3a83277 Mon Sep 17 00:00:00 2001
From: Elizabeth Figura <zfigura@codeweavers.com>
2024-12-16 07:12:49 +03:00
Date: Fri, 13 Dec 2024 13:34:46 -0600
Subject: ntsync: Introduce NTSYNC_IOC_CREATE_MUTEX.
This corresponds to the NT syscall NtCreateMutant().
An NT mutex is recursive, with a 32-bit recursion counter. When acquired via
NtWaitForMultipleObjects(), the recursion counter is incremented by one. The OS
records the thread which acquired it.
However, in order to keep this driver self-contained, the owning thread ID is
managed by user-space, and passed as a parameter to all relevant ioctls.
The initial owner and recursion count, if any, are specified when the mutex is
created.
Signed-off-by: Elizabeth Figura <zfigura@codeweavers.com>
---
2024-12-16 07:12:49 +03:00
drivers/misc/ntsync.c | 74 +++++++++++++++++++++++++++++++++++--
include/uapi/linux/ntsync.h | 9 ++++-
2 files changed, 79 insertions(+), 4 deletions(-)
--- a/drivers/misc/ntsync.c
+++ b/drivers/misc/ntsync.c
@@ -25,6 +25,7 @@
enum ntsync_type {
NTSYNC_TYPE_SEM,
+ NTSYNC_TYPE_MUTEX,
};
/*
@@ -55,6 +56,10 @@ struct ntsync_obj {
__u32 count;
__u32 max;
} sem;
+ struct {
+ __u32 count;
+ pid_t owner;
+ } mutex;
} u;
/*
@@ -92,6 +97,7 @@ struct ntsync_q_entry {
struct ntsync_q {
struct task_struct *task;
+ __u32 owner;
/*
* Protected via atomic_try_cmpxchg(). Only the thread that wins the
@@ -214,13 +220,17 @@ static void ntsync_unlock_obj(struct nts
((lockdep_is_held(&(obj)->dev->wait_all_lock) != LOCK_STATE_NOT_HELD) && \
(obj)->dev_locked))
-static bool is_signaled(struct ntsync_obj *obj)
+static bool is_signaled(struct ntsync_obj *obj, __u32 owner)
{
ntsync_assert_held(obj);
switch (obj->type) {
case NTSYNC_TYPE_SEM:
return !!obj->u.sem.count;
+ case NTSYNC_TYPE_MUTEX:
+ if (obj->u.mutex.owner && obj->u.mutex.owner != owner)
+ return false;
+ return obj->u.mutex.count < UINT_MAX;
}
WARN(1, "bad object type %#x\n", obj->type);
@@ -250,7 +260,7 @@ static void try_wake_all(struct ntsync_d
}
for (i = 0; i < count; i++) {
- if (!is_signaled(q->entries[i].obj)) {
+ if (!is_signaled(q->entries[i].obj, q->owner)) {
can_wake = false;
break;
}
@@ -264,6 +274,10 @@ static void try_wake_all(struct ntsync_d
case NTSYNC_TYPE_SEM:
obj->u.sem.count--;
break;
+ case NTSYNC_TYPE_MUTEX:
+ obj->u.mutex.count++;
+ obj->u.mutex.owner = q->owner;
+ break;
}
}
wake_up_process(q->task);
@@ -307,6 +321,30 @@ static void try_wake_any_sem(struct ntsy
}
}
+static void try_wake_any_mutex(struct ntsync_obj *mutex)
+{
+ struct ntsync_q_entry *entry;
+
+ ntsync_assert_held(mutex);
+ lockdep_assert(mutex->type == NTSYNC_TYPE_MUTEX);
+
+ list_for_each_entry(entry, &mutex->any_waiters, node) {
+ struct ntsync_q *q = entry->q;
+ int signaled = -1;
+
+ if (mutex->u.mutex.count == UINT_MAX)
+ break;
+ if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner)
+ continue;
+
+ if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
+ mutex->u.mutex.count++;
+ mutex->u.mutex.owner = q->owner;
+ wake_up_process(q->task);
+ }
+ }
+}
+
/*
* Actually change the semaphore state, returning -EOVERFLOW if it is made
* invalid.
2024-12-16 07:12:49 +03:00
@@ -451,6 +489,30 @@ static int ntsync_create_sem(struct ntsy
return fd;
}
+static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp)
+{
+ struct ntsync_mutex_args args;
+ struct ntsync_obj *mutex;
+ int fd;
+
+ if (copy_from_user(&args, argp, sizeof(args)))
+ return -EFAULT;
+
+ if (!args.owner != !args.count)
+ return -EINVAL;
+
+ mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX);
+ if (!mutex)
+ return -ENOMEM;
+ mutex->u.mutex.count = args.count;
+ mutex->u.mutex.owner = args.owner;
+ fd = ntsync_obj_get_fd(mutex);
2024-12-16 07:12:49 +03:00
+ if (fd < 0)
+ kfree(mutex);
+
2024-12-16 07:12:49 +03:00
+ return fd;
+}
+
static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd)
{
struct file *file = fget(fd);
2024-12-16 07:12:49 +03:00
@@ -520,7 +582,7 @@ static int setup_wait(struct ntsync_devi
struct ntsync_q *q;
__u32 i, j;
- if (args->pad[0] || args->pad[1] || args->pad[2] || (args->flags & ~NTSYNC_WAIT_REALTIME))
+ if (args->pad[0] || args->pad[1] || (args->flags & ~NTSYNC_WAIT_REALTIME))
return -EINVAL;
if (args->count > NTSYNC_MAX_WAIT_COUNT)
2024-12-16 07:12:49 +03:00
@@ -534,6 +596,7 @@ static int setup_wait(struct ntsync_devi
if (!q)
return -ENOMEM;
q->task = current;
+ q->owner = args->owner;
atomic_set(&q->signaled, -1);
q->all = all;
q->count = count;
2024-12-16 07:12:49 +03:00
@@ -576,6 +639,9 @@ static void try_wake_any_obj(struct ntsy
case NTSYNC_TYPE_SEM:
try_wake_any_sem(obj);
break;
+ case NTSYNC_TYPE_MUTEX:
+ try_wake_any_mutex(obj);
+ break;
}
}
2024-12-16 07:12:49 +03:00
@@ -765,6 +831,8 @@ static long ntsync_char_ioctl(struct fil
void __user *argp = (void __user *)parm;
switch (cmd) {
+ case NTSYNC_IOC_CREATE_MUTEX:
+ return ntsync_create_mutex(dev, argp);
case NTSYNC_IOC_CREATE_SEM:
return ntsync_create_sem(dev, argp);
case NTSYNC_IOC_WAIT_ALL:
--- a/include/uapi/linux/ntsync.h
+++ b/include/uapi/linux/ntsync.h
2024-12-16 07:12:49 +03:00
@@ -15,6 +15,11 @@ struct ntsync_sem_args {
__u32 max;
};
+struct ntsync_mutex_args {
+ __u32 owner;
+ __u32 count;
+};
+
#define NTSYNC_WAIT_REALTIME 0x1
struct ntsync_wait_args {
2024-12-16 07:12:49 +03:00
@@ -23,7 +28,8 @@ struct ntsync_wait_args {
__u32 count;
__u32 index;
__u32 flags;
- __u32 pad[3];
+ __u32 owner;
+ __u32 pad[2];
};
#define NTSYNC_MAX_WAIT_COUNT 64
2024-12-16 07:12:49 +03:00
@@ -31,6 +37,7 @@ struct ntsync_wait_args {
#define NTSYNC_IOC_CREATE_SEM _IOW ('N', 0x80, struct ntsync_sem_args)
#define NTSYNC_IOC_WAIT_ANY _IOWR('N', 0x82, struct ntsync_wait_args)
#define NTSYNC_IOC_WAIT_ALL _IOWR('N', 0x83, struct ntsync_wait_args)
2024-12-16 07:12:49 +03:00
+#define NTSYNC_IOC_CREATE_MUTEX _IOW ('N', 0x84, struct ntsync_mutex_args)
2024-12-16 07:12:49 +03:00
#define NTSYNC_IOC_SEM_RELEASE _IOWR('N', 0x81, __u32)