-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.py
executable file
·81 lines (66 loc) · 2.45 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
from typing import Optional
from pathlib import Path
from functools import partial
from io import BytesIO
import base64
from PIL import Image
import click
from chex import Array
import haiku as hk
import jax
from flask import Flask, Response, request
import numpy as np
from kigo.utils import get_logger, get_rngs
from kigo import persistence
from kigo import nn
from kigo import diffusion
from kigo.configs import Config
logger = get_logger('kigo.server')


# Command-line entry point: load a checkpoint, wrap it in a diffusion sampler
# and serve it over HTTP with Flask.
#   CHECKPOINT  resolved through ``persistence.get_checkpoint``
#   --port/-p   TCP port to listen on (default 8080)
#   --host/-h   interface to bind (default 127.0.0.1)
#   --seed      forwarded to ``get_rngs`` (None when omitted)
# NOTE: the docstring below is the click help text, so it is kept verbatim.
@click.command('Server')
@click.argument('checkpoint', type=persistence.get_checkpoint)
@click.option('--port', '-p', type=int, default=8080)
@click.option('--host', '-h', type=str, default='127.0.0.1')
@click.option('--seed', type=int, default=None)
def cli(checkpoint: Path, port: int, host: str, seed: Optional[int]) -> None:
    '''A server wrapping Kigo!'''
    # Key sequence shared by every request handled by the app.
    key_seq = get_rngs(seed)
    # Restore the config and the parameters returned by ``load_ema``.
    config = persistence.load_cfg(checkpoint)
    ema_params = persistence.load_ema(checkpoint)
    # The sampler receives a model factory bound to the restored config.
    model_fn = partial(nn.Model.from_cfg, config)
    server = build_app(diffusion.Sampler(ema_params, model_fn), config, key_seq)
    # Blocks until the server is stopped.
    server.run(host=host, port=port)
def build_app(sampler: diffusion.Sampler,
              cfg: Config,
              rngs: hk.PRNGSequence) -> Flask:
    """Create the Flask application that serves one sampled image per request.

    ``GET /`` accepts the optional query parameters ``steps`` (int,
    0 < steps <= 1024, default 64), ``eta`` (float in [0, 1], default 0.1)
    and ``clip`` (float in [0, 1], default 0.1). Missing, malformed or
    out-of-range values yield a 400 response. Valid values are forwarded to
    ``sampler.sample_p`` together with two keys drawn from ``rngs``. The
    result is returned as an HTML page with the image embedded inline.

    Args:
        sampler: Sampler whose ``sample_p`` produces the returned image.
        cfg: Configuration; ``cfg.img.shape`` is the per-sample noise shape.
        rngs: Key sequence shared by all requests; two keys are drawn per
            request (initial noise, then sampler).

    Returns:
        The configured, not yet running, Flask application.
    """
    import threading

    app = Flask(__name__)
    # ``app.run()`` handles requests on several threads by default, while
    # ``rngs`` is a shared stateful iterator. Draw keys under a lock so two
    # concurrent requests cannot interleave inside ``next(rngs)``.
    rng_lock = threading.Lock()

    def read_arg(name, default, cast):
        """Return ``cast`` of query arg *name* (or *default*); None if it does not parse."""
        try:
            return cast(request.args.get(name, default))
        except (TypeError, ValueError):
            return None

    @app.route('/', methods=['GET'])
    def index() -> Response:
        # Validation. A value that fails to parse (e.g. ``?steps=abc``) is a
        # client error: answer 400 instead of letting ValueError become a 500.
        steps = read_arg('steps', 64, int)
        if steps is None or not (0 < steps <= 1024):
            return Response('Invalid "steps" value: 0 < steps <= 1024.', 400)
        eta = read_arg('eta', 0.1, float)
        if eta is None or not (0. <= eta <= 1.):
            return Response('Invalid "eta" value: 0 <= eta <= 1.', 400)
        clip = read_arg('clip', 0.1, float)
        if clip is None or not (0. <= clip <= 1.):
            return Response('Invalid "clip" value: 0 <= clip <= 1.', 400)

        # Run sampler. Keys are drawn in the original order: first for the
        # initial noise, second for the sampler.
        with rng_lock:
            noise_key = next(rngs)
            sample_key = next(rngs)
        xT = jax.random.normal(noise_key, shape=(1, *cfg.img.shape))
        x0 = sampler.sample_p(xT, steps, sample_key, eta, clip)
        return Response(to_html(x0))

    return app
def to_base64_img(xt: Array) -> str:
    """Encode a single sample as a PNG ``data:`` URI.

    Args:
        xt: Batch holding exactly one image; its leading axis (size 1) is
            removed with ``squeeze(0)``. The ``* 0.5 + 0.5`` rescale assumes
            pixel values in [-1, 1]. The remaining layout is presumably
            (H, W) or (H, W, C); TODO confirm against ``cfg.img.shape``.

    Returns:
        ``"data:image/png;base64,<payload>"``, usable as an ``<img src>``.
    """
    # Map [-1, 1] -> [0, 1].
    pixels = np.asarray(xt).squeeze(0) * 0.5 + 0.5
    # Clamp before the uint8 cast. Out-of-range floats (sampler overshoot)
    # would otherwise overflow/wrap and show up as bright or dark speckles.
    pixels = np.clip(255 * pixels, 0, 255).astype(np.uint8)
    # PIL cannot build an image from an (H, W, 1) array. Drop a singleton
    # channel axis so single-channel samples render as greyscale.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    with BytesIO() as buffer:
        Image.fromarray(pixels).save(buffer, format='PNG')
        payload = base64.b64encode(buffer.getvalue()).decode()
    return "data:image/png;base64," + payload
def to_html(xt: Array) -> str:
    """Return an HTML ``<img>`` element that embeds *xt* as an inline PNG."""
    data_uri = to_base64_img(xt)
    return '<img src="' + data_uri + '"/>'
if __name__ == '__main__':
    # Script entry point: click parses sys.argv and starts the server.
    cli()