mirror of
https://github.com/cirosantilli/linux-kernel-module-cheat.git
synced 2026-01-26 03:31:36 +01:00
thread_pool: support passing thread IDs
Then use that to fix gem5 error log read race.
This commit is contained in:
@@ -31,6 +31,7 @@ class ThreadPool:
|
||||
python3 thread_pool.py 2 -10 20 1
|
||||
python3 thread_pool.py 2 -10 20 2
|
||||
python3 thread_pool.py 2 -10 20 3
|
||||
python3 thread_pool.py 2 -10 20 0 1
|
||||
....
|
||||
|
||||
These ensure that execution stops neatly on error.
|
||||
@@ -39,7 +40,8 @@ class ThreadPool:
|
||||
self,
|
||||
func: Callable,
|
||||
handle_output: Union[Callable[[Any,Any,Exception],Any],None] = None,
|
||||
nthreads: Union[int,None] = None
|
||||
nthreads: Union[int,None] = None,
|
||||
thread_id_arg: Union[str,None] = None,
|
||||
):
|
||||
'''
|
||||
Start in a thread pool immediately.
|
||||
@@ -62,6 +64,9 @@ class ThreadPool:
|
||||
|
||||
Default: a handler that does nothing and just exits on exception.
|
||||
:param nthreads: number of threads to use. Default: nproc.
|
||||
:param thread_id_arg: if not None, set the argument of func with this name
|
||||
to a 0-indexed thread ID. This allows function calls to coordinate
|
||||
usage of external resources such as files or ports.
|
||||
'''
|
||||
self.func = func
|
||||
if handle_output is None:
|
||||
@@ -69,6 +74,7 @@ class ThreadPool:
|
||||
self.handle_output = handle_output
|
||||
if nthreads is None:
|
||||
nthreads = len(os.sched_getaffinity(0))
|
||||
self.thread_id_arg = thread_id_arg
|
||||
self.nthreads = nthreads
|
||||
self.error_output = None
|
||||
self.error_output_lock = threading.Lock()
|
||||
@@ -77,6 +83,7 @@ class ThreadPool:
|
||||
for i in range(self.nthreads):
|
||||
thread = threading.Thread(
|
||||
target=self._func_runner,
|
||||
args=(i,)
|
||||
)
|
||||
self.threads.append(thread)
|
||||
thread.start()
|
||||
@@ -123,11 +130,13 @@ class ThreadPool:
|
||||
thread.join()
|
||||
return self.error_output
|
||||
|
||||
def _func_runner(self):
|
||||
def _func_runner(self, thread_id):
|
||||
while True:
|
||||
work = self.in_queue.get(block=True)
|
||||
if work is None:
|
||||
break
|
||||
if self.thread_id_arg is not None:
|
||||
work[self.thread_id_arg] = thread_id
|
||||
try:
|
||||
exception = None
|
||||
out = self.func(**work)
|
||||
@@ -147,7 +156,7 @@ class ThreadPool:
|
||||
self.in_queue.task_done()
|
||||
|
||||
if __name__ == '__main__':
|
||||
def my_func(i):
|
||||
def func_maybe_raise(i):
|
||||
'''
|
||||
The main function that will be evaluated.
|
||||
|
||||
@@ -156,6 +165,10 @@ if __name__ == '__main__':
|
||||
time.sleep((abs(i) % 4) / 10.0)
|
||||
return 10.0 / i
|
||||
|
||||
def func_get_thread(i, thread_id):
|
||||
time.sleep((abs(i) % 4) / 10.0)
|
||||
return thread_id
|
||||
|
||||
def get_work(min_, max_):
|
||||
'''
|
||||
Generate simple range work for my_func.
|
||||
@@ -202,14 +215,17 @@ if __name__ == '__main__':
|
||||
nthreads = None
|
||||
else:
|
||||
nthreads = None
|
||||
|
||||
if argv_len > 2:
|
||||
min_ = int(sys.argv[2])
|
||||
else:
|
||||
min_ = 1
|
||||
|
||||
if argv_len > 3:
|
||||
max_ = int(sys.argv[3])
|
||||
else:
|
||||
max_ = 100
|
||||
|
||||
if argv_len > 4:
|
||||
c = sys.argv[4][0]
|
||||
else:
|
||||
@@ -223,11 +239,23 @@ if __name__ == '__main__':
|
||||
else:
|
||||
handle_output = handle_output_print
|
||||
|
||||
if argv_len > 5:
|
||||
c = sys.argv[5][0]
|
||||
else:
|
||||
c = '0'
|
||||
if c == '1':
|
||||
my_func = func_get_thread
|
||||
thread_id_arg = 'thread_id'
|
||||
else:
|
||||
my_func = func_maybe_raise
|
||||
thread_id_arg = None
|
||||
|
||||
# Action.
|
||||
thread_pool = ThreadPool(
|
||||
my_func,
|
||||
handle_output,
|
||||
nthreads
|
||||
nthreads,
|
||||
thread_id_arg,
|
||||
)
|
||||
for work in get_work(min_, max_):
|
||||
error = thread_pool.submit(work)
|
||||
|
||||
Reference in New Issue
Block a user