Skip to content

Distributed #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.env
filtered/__pycache__
24 changes: 24 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
FROM --platform=linux/x86_64 ubuntu:24.04

ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=UTC

RUN apt-get update && \
apt-get install -y \
wget \
xz-utils \
bzip2 \
git \
python3-pip \
python3 \
&& apt-get install -y software-properties-common \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

COPY requirements.txt .

RUN pip install -r requirements.txt --break-system-packages

COPY filtered/ ./filtered/

CMD ["python3.11", "-m", "celery", "-A", "filtered.worker", "worker", "--loglevel=info", "--concurrency=1"]
4 changes: 4 additions & 0 deletions filtered/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .distributed import *
from .filter import *
from .split import *
from .worker import *
85 changes: 85 additions & 0 deletions filtered/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from distributask.distributask import create_from_config
from .filter import read_json_in_batches
from .worker import run_job
from tqdm import tqdm
import time

if __name__ == "__main__":

input_filename = "datasets/cap3d_captions.json"
batch_size = 10000

distributask = create_from_config()

max_price = 0.25
max_nodes = 25
docker_image = "antbaez/filter-worker:latest"
module_name = "filtered.worker"

redis_client = distributask.get_redis_connection()

rented_nodes = distributask.rent_nodes(
max_price, max_nodes, docker_image, module_name
)
print("Total nodes rented: ", len(rented_nodes))

distributask.register_function(run_job)

while True:
user_input = input("press r when workers are ready: ")
if user_input == "r":
break

total_batches = 0

print("Sending tasks")
tasks = []

json_batches = [batch for batch in read_json_in_batches(input_filename, batch_size)]
print(f"number of batches: {len(json_batches)}")

num_batches = len(json_batches)
for i in range(num_batches):

batch = json_batches[i]
total_batches += 1

print(total_batches)
task = distributask.execute_function(
"run_job", {"batch_index": total_batches, "batch": batch}
)

tasks.append(task)

first_task_done = False
print("Tasks sent. Starting monitoring")

inactivity_log = {node["instance_id"]: 0 for node in rented_nodes}

start_time = time.time()
with tqdm(total=len(tasks), unit="task") as pbar:
while not all(task.ready() for task in tasks):

current_tasks = sum([task.ready() for task in tasks])
pbar.update(current_tasks - pbar.n)

time.sleep(1)

current_time = time.time()
if current_time - start_time > 60:
start_time = time.time()

for node in rented_nodes:
log_response = distributask.get_node_log(node)
if log_response.status_code == 200:
try:
last_msg = log_response.text.splitlines()[-1]
if ("Task complete" in last_msg and inactivity_log[node["instance_id"]] == 0):
inactivity_log[node["instance_id"]] = 1
elif ("Task complete" in last_msg and inactivity_log[node["instance_id"]] == 1):
distributask.terminate_nodes([node])
print("node terminated")
else:
inactivity_log[node["instance_id"]] == 0
except:
pass
6 changes: 3 additions & 3 deletions filtered/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def detect_objects(self, caption):
if not objects:
objects = self._extract_noun_phrases(caption)

print("These are the objects:", objects)
# print("These are the objects:", objects)
return objects

def _extract_noun_phrases(self, text):
Expand Down Expand Up @@ -147,12 +147,12 @@ def test_caption_filtering():

if total_filtered_count >= write_batch_size or current_batch == total_batches:
write_filtered_json(output_filename, filtered_data, first_batch=first_batch, last_batch=(current_batch == total_batches))
print(f"Wrote batch {current_batch}/{total_batches} with {total_filtered_count} filtered captions")
# print(f"Wrote batch {current_batch}/{total_batches} with {total_filtered_count} filtered captions")
filtered_data = {}
total_filtered_count = 0
first_batch = False

print("Filtering and writing completed.")
# print("Filtering and writing completed.")

# Optionally, you can keep the test function call if you want to run tests
# test_caption_filtering()
27 changes: 27 additions & 0 deletions filtered/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys
from .filter import filter_captions, write_filtered_json

def run_job(batch_index, batch):

if len(str(batch_index)) == 1:
batch_num = f"0{batch_index}"
else:
batch_num = f"{batch_index}"

output_filename = f"batch_{batch_num}"

filtered_batch = filter_captions(batch)
write_filtered_json(output_filename, filtered_batch)

distributask.upload_file(output_filename)

return "Task complete"


if __name__ == "__main__" or any("celery" in arg for arg in sys.argv):
from distributask.distributask import create_from_config

distributask = create_from_config()
distributask.register_function(run_job)

celery = distributask.app
10 changes: 10 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
requests
fsspec
celery
redis
huggingface_hub
python-dotenv
omegaconf
tqdm
gliner
distributask