thread_pool: support passing thread IDs

Then use that to fix gem5 error log read race.
This commit is contained in:
Ciro Santilli 六四事件 法轮功
2019-05-05 00:00:00 +00:00
parent b49ebb1c8a
commit 5daad53289
5 changed files with 52 additions and 16 deletions

View File

@@ -31,6 +31,7 @@ class ThreadPool:
python3 thread_pool.py 2 -10 20 1
python3 thread_pool.py 2 -10 20 2
python3 thread_pool.py 2 -10 20 3
python3 thread_pool.py 2 -10 20 0 1
....
These ensure that execution stops neatly on error.
@@ -39,7 +40,8 @@ class ThreadPool:
self,
func: Callable,
handle_output: Union[Callable[[Any,Any,Exception],Any],None] = None,
nthreads: Union[int,None] = None
nthreads: Union[int,None] = None,
thread_id_arg: Union[str,None] = None,
):
'''
Start in a thread pool immediately.
@@ -62,6 +64,9 @@ class ThreadPool:
Default: a handler that does nothing and just exits on exception.
:param nthreads: number of threads to use. Default: nproc.
:param thread_id_arg: if not None, set the argument of func with this name
to a 0-indexed thread ID. This allows function calls to coordinate
usage of external resources such as files or ports.
'''
self.func = func
if handle_output is None:
@@ -69,6 +74,7 @@ class ThreadPool:
self.handle_output = handle_output
if nthreads is None:
nthreads = len(os.sched_getaffinity(0))
self.thread_id_arg = thread_id_arg
self.nthreads = nthreads
self.error_output = None
self.error_output_lock = threading.Lock()
@@ -77,6 +83,7 @@ class ThreadPool:
for i in range(self.nthreads):
thread = threading.Thread(
target=self._func_runner,
args=(i,)
)
self.threads.append(thread)
thread.start()
@@ -123,11 +130,13 @@ class ThreadPool:
thread.join()
return self.error_output
def _func_runner(self):
def _func_runner(self, thread_id):
while True:
work = self.in_queue.get(block=True)
if work is None:
break
if self.thread_id_arg is not None:
work[self.thread_id_arg] = thread_id
try:
exception = None
out = self.func(**work)
@@ -147,7 +156,7 @@ class ThreadPool:
self.in_queue.task_done()
if __name__ == '__main__':
def my_func(i):
def func_maybe_raise(i):
'''
The main function that will be evaluated.
@@ -156,6 +165,10 @@ if __name__ == '__main__':
time.sleep((abs(i) % 4) / 10.0)
return 10.0 / i
def func_get_thread(i, thread_id):
time.sleep((abs(i) % 4) / 10.0)
return thread_id
def get_work(min_, max_):
'''
Generate simple range work for my_func.
@@ -202,14 +215,17 @@ if __name__ == '__main__':
nthreads = None
else:
nthreads = None
if argv_len > 2:
min_ = int(sys.argv[2])
else:
min_ = 1
if argv_len > 3:
max_ = int(sys.argv[3])
else:
max_ = 100
if argv_len > 4:
c = sys.argv[4][0]
else:
@@ -223,11 +239,23 @@ if __name__ == '__main__':
else:
handle_output = handle_output_print
if argv_len > 5:
c = sys.argv[5][0]
else:
c = '0'
if c == '1':
my_func = func_get_thread
thread_id_arg = 'thread_id'
else:
my_func = func_maybe_raise
thread_id_arg = None
# Action.
thread_pool = ThreadPool(
my_func,
handle_output,
nthreads
nthreads,
thread_id_arg,
)
for work in get_work(min_, max_):
error = thread_pool.submit(work)