Creating your own PyWorker can be complex and challenging, with many potential pitfalls. If you need assistance with adding new PyWorkers, please don't hesitate to contact us.
This guide walks you through adding new backends. It is taken from the hello_world worker's README in the Vast PyWorker repository.
There is a hello_world PyWorker implementation under workers/hello_world. This PyWorker is created for an LLM model server that runs on port 5001 and has two API endpoints:
/generate: generates a full response to the prompt and sends it back as a JSON response
/generate_stream: streams a response one token at a time
Both of these endpoints take the same JSON payload:
{
    "prompt": String,
    "max_response_tokens": Number | null
}
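For example, a /generate request body might look like this (the values are illustrative, not taken from the repository):
{
    "prompt": "Write a short poem about GPUs",
    "max_response_tokens": 128
}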
We want the PyWorker to expose two corresponding endpoints, one for each of the model API endpoints above.
All PyWorkers should have the following files:
.
└── workers
└── hello_world
├── __init__.py
├── data_types.py # contains data types representing model API endpoints
├── server.py # contains endpoint handlers
├── client.py # a script to call an endpoint through the autoscaler
└── test_load.py # script for load testing
All of the classes follow strict type hinting. It is recommended that you type hint all of your functions. This allows your IDE, or VSCode with the pyright plugin, to find any type errors in your implementation. You can also install pyright with npm install pyright and run pyright from the root of the project to find any type errors.
Data classes representing the model API are defined in data_types.py. They must inherit from lib.data_types.ApiPayload. ApiPayload is an abstract class, and you need to implement several of its functions:
import dataclasses
import inspect
import random
from typing import Any, Dict

from transformers import AutoTokenizer  # used to count tokens in a prompt
import nltk  # used to download a list of all words to generate a random prompt and benchmark the LLM model

from lib.data_types import ApiPayload, JsonDataException

nltk.download("words")
WORD_LIST = nltk.corpus.words.words()

# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")


@dataclasses.dataclass
class InputData(ApiPayload):
    prompt: str
    max_response_tokens: int

    @classmethod
    def for_test(cls) -> "ApiPayload":
        """defines how to create a payload for load testing"""
        prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
        return cls(prompt=prompt, max_response_tokens=300)

    def generate_payload_json(self) -> Dict[str, Any]:
        """defines how to convert an ApiPayload to JSON that will be sent to the model API"""
        return dataclasses.asdict(self)

    def count_workload(self) -> float:
        """defines how to calculate the workload for a payload"""
        return len(tokenizer.tokenize(self.prompt))

    @classmethod
    def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
        """
        defines how to transform JSON data into AuthData and the payload type;
        in this case `InputData` defined above represents the data sent to the model API.
        AuthData is data generated by the autoscaler in order to authenticate payloads.
        In this case, the transformation is simple and 1:1. That is not always the case; see comfyui's PyWorker
        for more complicated examples.
        """
        errors = {}
        for param in inspect.signature(cls).parameters:
            if param not in json_msg:
                errors[param] = "missing parameter"
        if errors:
            raise JsonDataException(errors)
        return cls(
            **{
                k: v
                for k, v in json_msg.items()
                if k in inspect.signature(cls).parameters
            }
        )
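The following is a minimal sketch, not part of the worker, showing how these methods fit together (the prompt and token values are illustrative):

# illustrative only: exercises the InputData methods defined above
data = InputData.from_json_msg({"prompt": "Hello, world", "max_response_tokens": 64})
print(data.count_workload())         # number of prompt tokens; used by the autoscaler as the workload measure
print(data.generate_payload_json())  # dict that will be serialized and sent to the model API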
For every model API endpoint you want to use, you must implement an EndpointHandler. This class handles incoming requests, processes them, sends them to the model API server, and finally returns an HTTP response. EndpointHandler has several abstract functions that must be implemented. Here, we implement two: one for /generate and one for /generate_stream:
"""
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
work.
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
{
auth_data: AuthData,
payload : InputData # defined above
}
"""
from aiohttp import web
from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server
from .data_types import InputData
#### This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@classmethod
def payload_cls(cls) -> Type[InputData]:
"""this function should just return ApiPayload subclass used by this handler"""
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
"""
defines how to convert `InputData` defined above, to
JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
can be more complicated, See comfyui for an example
"""
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
"""
defines how to generate an InputData for benchmarking. This needs to be defined in only
one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
method on InputData. However, in some cases you might need to fine tune your InputData used for
benchmarking to closely resemble the average request users call the endpoint with in order to get best
autoscaling performance
"""
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
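To make the data flow concrete, here is an illustrative sketch, not part of the worker; in practice the Backend invokes these methods when forwarding a request:

# illustrative only: the Backend normally drives these calls for you
handler = GenerateHandler()
payload = handler.make_benchmark_payload()               # an InputData built via InputData.for_test()
model_api_body = handler.generate_payload_json(payload)  # dict POSTed to the model API's /generate endpoint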
We also implement GenerateStreamHandler for streaming responses. It is identical to GenerateHandler except for the endpoint name and how we create a web response, since it is a streaming response:
class GenerateStreamHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        return "/generate_stream"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        match model_response.status:
            case 200:
                log.debug("Streaming response...")
                res = web.StreamResponse()
                res.content_type = "text/event-stream"
                await res.prepare(client_request)
                async for chunk in model_response.content:
                    await res.write(chunk)
                await res.write_eof()
                log.debug("Done streaming response")
                return res
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)
You can now instantiate a Backend and use it to handle requests.
from lib.backend import Backend, LogAction

# the url and port of the model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"

# this is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
    "Exception: corrupted model file"  # message in the logs indicating an unrecoverable error
]

backend = Backend(
    model_server_url=MODEL_SERVER_URL,
    # location of the model log file
    model_log_file=os.environ["MODEL_LOG"],
    # for model servers that can only handle one request at a time, be sure to set this to False to
    # let the PyWorker handle queueing requests
    allow_parallel_requests=True,
    # give the backend an EndpointHandler instance that is used for benchmarking;
    # the number of benchmark runs and the number of words per random benchmark prompt are given
    benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
    # defines how to handle specific log messages. See the docstring of LogAction for details
    log_actions=[
        (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
        (LogAction.Info, '"message":"Download'),
        *[
            (LogAction.ModelError, error_msg)
            for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
        ],
    ],
)


# this is a simple ping handler for the PyWorker
async def handle_ping(_: web.Request):
    return web.Response(body="pong")


# this is a handler that forwards a health check to the model API
async def handle_healthcheck(_: web.Request):
    healthcheck_res = await backend.session.get("/healthcheck")
    return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)


routes = [
    web.post("/generate", backend.create_handler(GenerateHandler())),
    web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
    web.get("/ping", handle_ping),
    web.get("/healthcheck", handle_healthcheck),
]

if __name__ == "__main__":
    # start the server; called from start_server.sh
    start_server(backend, routes)
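For reference, a client that has been routed to this PyWorker POSTs a body shaped as the module docstring above describes. The sketch below is a hedged illustration only; the auth_data fields are issued by the autoscaler, so see lib.data_types.AuthData and Vast's Serverless documentation for the authoritative structure:

# illustration only: auth_data contents come from the autoscaler, not from this example
request_body = {
    "auth_data": {...},  # placeholder: signed routing/authentication data issued by the autoscaler
    "payload": {
        "prompt": "Write a short poem about GPUs",  # illustrative values
        "max_response_tokens": 128,
    },
}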
In test_load.py, you can create a script that allows you to load test an endpoint group whose instances run this PyWorker:
from lib.test_harness import run
from .data_types import InputData

WORKER_ENDPOINT = "/generate"

if __name__ == "__main__":
    run(InputData.for_test(), WORKER_ENDPOINT)
You can then run the following command from the root of this repo to load test the endpoint group:
# sends 1000 requests at the rate of 0.5 requests per second
python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"