#!/usr/bin/env python3 from typing import Any, Callable, Dict, Iterable, Union import os import queue import sys import threading import time class ThreadPool: ''' Start a pool of a limited number of threads to do some work. This is similar to the stdlib concurrent, but I could not find how to reach all my design goals with that implementation: * the input function does not need to be modified * limit the number of threads * queue sizes closely follow number of threads * if an exception happens, optionally stop soon afterwards Functional form and further discussion at: https://stackoverflow.com/questions/19369724/the-right-way-to-limit-maximum-number-of-threads-running-at-once/55263676#55263676 This class form allows to use your own while loops with submit(). Quick test with: .... python3 thread_pool.py 2 -10 20 0 python3 thread_pool.py 2 -10 20 1 python3 thread_pool.py 2 -10 20 2 python3 thread_pool.py 2 -10 20 3 python3 thread_pool.py 2 -10 20 0 1 .... These ensure that execution stops neatly on error. ''' def __init__( self, func: Callable, handle_output: Union[Callable[[Any,Any,Exception],Any],None] = None, nthreads: Union[int,None] = None, thread_id_arg: Union[str,None] = None, ): ''' Start in a thread pool immediately. join() must be called afterwards at some point. :param func: main work function to be evaluated. :param handle_output: called on func return values as they are returned. Signature is: handle_output(input, output, exception) where: * input: input given to func * output: return value of func * exception: the exception that func raised, or None otherwise If this function returns non-None or raises, stop feeding new input and exit ASAP when all currently running threads have finished. Default: a handler that does nothing and just exits on exception. :param nthreads: number of threads to use. Default: nproc. :param thread_id_arg: if not None, set the argument of func with this name to a 0-indexed thread ID. This allows function calls to coordinate usage of external resources such as files or ports. ''' self.func = func if handle_output is None: handle_output = lambda input, output, exception: exception self.handle_output = handle_output if nthreads is None: nthreads = len(os.sched_getaffinity(0)) self.thread_id_arg = thread_id_arg self.nthreads = nthreads self.error_output = None self.error_output_lock = threading.Lock() self.in_queue = queue.Queue(maxsize=nthreads) self.threads = [] for i in range(self.nthreads): thread = threading.Thread( target=self._func_runner, args=(i,) ) self.threads.append(thread) thread.start() def __enter__(self): ''' __exit__ automatically calls join() for you. This is cool because it automatically ends the loop if an exception occurs. But don't forget that errors may happen after the last submit is called, so you likely want to check for that with get_error after the with. get_error() returns the same as the explicit join(). ''' return self def __exit__(self, type, value, traceback): self.join() def get_error(self): return self.error_output def submit(self, work): ''' Submit work. Block if there is already enough work scheduled (~nthreads). :return: if an error occurred in some previously executed thread, the error. Otherwise, None. This allows the caller to stop submitting further work if desired. ''' self.in_queue.put(work) return self.error_output def join(self): ''' Request all threads to stop after they finish currently submitted work. :return: same as submit() ''' for thread in range(self.nthreads): self.in_queue.put(None) for thread in self.threads: thread.join() return self.error_output def _func_runner(self, thread_id): while True: work = self.in_queue.get(block=True) if work is None: break if self.thread_id_arg is not None: work[self.thread_id_arg] = thread_id try: exception = None out = self.func(**work) except Exception as e: exception = e out = None try: handle_output_return = self.handle_output(work, out, exception) except Exception as e: with self.error_output_lock: self.error_output = (work, out, e) else: if handle_output_return is not None: with self.error_output_lock: self.error_output = handle_output_return finally: self.in_queue.task_done() if __name__ == '__main__': def func_maybe_raise(i): ''' The main function that will be evaluated. It sleeps to simulate an IO operation. ''' time.sleep((abs(i) % 4) / 10.0) return 10.0 / i def func_get_thread(i, thread_id): time.sleep((abs(i) % 4) / 10.0) return thread_id def get_work(min_, max_): ''' Generate simple range work for my_func. ''' for i in range(min_, max_): yield {'i': i} def handle_output_print(input, output, exception): ''' Print outputs and exit immediately on failure. ''' print('{!r} {!r} {!r}'.format(input, output, exception)) return exception def handle_output_print_no_exit(input, output, exception): ''' Print outputs, don't exit on failure. ''' print('{!r} {!r} {!r}'.format(input, output, exception)) out_queue = queue.Queue() def handle_output_queue(input, output, exception): ''' Store outputs in a queue for later usage. ''' global out_queue out_queue.put((input, output, exception)) return exception def handle_output_raise(input, output, exception): ''' Raise if input == 10, to test that execution stops nicely if this raises. ''' print('{!r} {!r} {!r}'.format(input, output, exception)) if input['i'] == 10: raise Exception # CLI arguments. argv_len = len(sys.argv) if argv_len > 1: nthreads = int(sys.argv[1]) if nthreads == 0: nthreads = None else: nthreads = None if argv_len > 2: min_ = int(sys.argv[2]) else: min_ = 1 if argv_len > 3: max_ = int(sys.argv[3]) else: max_ = 100 if argv_len > 4: c = sys.argv[4][0] else: c = '0' if c == '1': handle_output = handle_output_print_no_exit elif c == '2': handle_output = handle_output_queue elif c == '3': handle_output = handle_output_raise else: handle_output = handle_output_print if argv_len > 5: c = sys.argv[5][0] else: c = '0' if c == '1': my_func = func_get_thread thread_id_arg = 'thread_id' else: my_func = func_maybe_raise thread_id_arg = None # Action. thread_pool = ThreadPool( my_func, handle_output, nthreads, thread_id_arg, ) for work in get_work(min_, max_): error = thread_pool.submit(work) if error is not None: break error = thread_pool.join() if error is not None: print('error: {!r}'.format(error)) if handle_output == handle_output_queue: while not out_queue.empty(): print(out_queue.get())