import http.server
import urllib.request
import json
import threading
import multiprocessing as mp
from constants import *


def encode(data):
    """Serialize *data* as JSON and return the UTF-8 encoded bytes."""
    text = json.dumps(data)
    return text.encode('utf-8')


def decode(s):
    """Decode UTF-8 bytes *s* and parse them as JSON, returning the object."""
    text = s.decode('utf-8')
    return json.loads(text)


class ATask:
    """Client worker that pushes a counter from 0 towards Y through B.

    Each round trip POSTs ``{'task_id': ..., 'i': ...}`` to B and adopts the
    ``i`` found in B's reply. After Y round trips the counter is asserted to
    equal Y.
    """

    def __init__(self, task_id):
        self.task_id = task_id

    @staticmethod
    def _post_until_delivered(endpoint, payload):
        """POST *payload* to *endpoint*, resending it on ConnectionResetError.

        Returns the decoded JSON body of the first attempt that succeeds.
        """
        while True:   # see report.txt for explanation
            try:
                with urllib.request.urlopen(endpoint, payload) as reply:
                    return decode(reply.read())
            except ConnectionResetError:
                continue

    def run(self):
        """Do Y round trips with B, checking that every reply carries our task id."""
        endpoint = f'http://localhost:{B_PORT}'  # loop-invariant: build it once
        counter = 0
        for _ in range(Y):
            payload = encode({'task_id': self.task_id, 'i': counter})
            answer = self._post_until_delivered(endpoint, payload)
            assert answer['task_id'] == self.task_id
            counter = answer['i']
        assert counter == Y


class A:
    """Process A: hosts a trivial GET endpoint and drives X ATask client threads.

    B's request handler GETs A's endpoint while serving each request, so A's
    server is kept running until all of A's tasks have finished.
    """

    class HandlerOK(http.server.BaseHTTPRequestHandler):
        """Reply to every GET with an empty ``200`` response."""

        def do_GET(self):
            self.send_response(200)
            self.end_headers()

        def log_message(self, *args):
            pass  # silence request logging

    def run(self):
        """Serve on A_PORT, run X tasks against B once B is ready, then notify B.

        ``a_done`` is set in a ``finally`` so that B, which blocks on it, is
        released even if binding the port or running the tasks raises.
        """
        try:
            self._serve_and_run_tasks()
        finally:
            a_done.set()

    def _serve_and_run_tasks(self):
        """Run all tasks while A's server answers from a background thread."""
        # The ``with`` block calls server_close() and releases the listening
        # socket. shutdown() alone only stops the serve_forever() loop.
        with http.server.HTTPServer(('', A_PORT), A.HandlerOK) as server:
            server_thread = threading.Thread(target=server.serve_forever)
            # start() sits outside the try: shutdown() deadlocks if
            # serve_forever() was never started in another thread.
            server_thread.start()
            try:
                b_ready.wait()
                self._run_tasks()
            finally:
                server.shutdown()
                server_thread.join()

    @staticmethod
    def _run_tasks():
        """Start one ATask thread per task id in range(X) and join them all."""
        task_threads = [threading.Thread(target=ATask(task_id).run) for task_id in range(X)]
        for thread in task_threads:
            thread.start()
        for thread in task_threads:
            thread.join()


class B:
    """Process B: increments counters POSTed by ATask, calling A on each request."""

    class HandlerInc(http.server.BaseHTTPRequestHandler):
        """Handle one POST: GET A's endpoint, then reply with the JSON body's ``i`` + 1."""

        def do_POST(self):
            # Contact A first. If this raises, no response is written and the
            # client sees the connection drop.
            with urllib.request.urlopen(f'http://localhost:{A_PORT}') as conn:
                conn.read()
            content_length = int(self.headers['Content-Length'])
            data = decode(self.rfile.read(content_length))
            data['i'] += 1
            body = encode(data)
            self.send_response(200)
            # Explicit framing: the client need not rely on the connection
            # closing to know where the body ends.
            self.send_header('Content-Type', 'application/json')
            self.send_header('Content-Length', str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def log_message(self, *args):
            pass  # silence request logging

    def run(self):
        """Serve on B_PORT, signal ``b_ready``, and stop once ``a_done`` is set."""
        # ThreadingHTTPServer handles creating threads for us.
        # The ``with`` block calls server_close() and releases the listening socket.
        with http.server.ThreadingHTTPServer(('', B_PORT), B.HandlerInc) as server:
            server_thread = threading.Thread(target=server.serve_forever)
            # start() sits outside the try: shutdown() deadlocks if
            # serve_forever() was never started in another thread.
            server_thread.start()
            try:
                b_ready.set()
                a_done.wait()
            finally:
                server.shutdown()
                server_thread.join()


def _run_in_child(role_cls, b_ready_event, a_done_event):
    """Child-process entry point: install the shared events, then run *role_cls*.

    A.run and B.run read ``b_ready`` and ``a_done`` as module globals. With the
    ``fork`` start method the child inherits them. With ``spawn`` (the default
    on Windows and macOS) the child re-imports this module instead, so the
    parent's events are passed in and re-bound as globals here.
    """
    global b_ready, a_done
    b_ready = b_ready_event
    a_done = a_done_event
    role_cls().run()


if __name__ == '__main__':
    # The guard stops a spawned child, which re-imports this module, from
    # starting processes of its own. The events remain module globals (A.run
    # and B.run read them); _run_in_child re-binds them inside each child.
    b_ready = mp.Event()  # set by B after its server thread is started
    a_done = mp.Event()   # set by A when its run finishes

    a = mp.Process(target=_run_in_child, args=(A, b_ready, a_done))
    b = mp.Process(target=_run_in_child, args=(B, b_ready, a_done))
    a.start()
    b.start()
    a.join()
    b.join()
