
Extension Guide

Creating your own PyWorker can be complex and challenging, with many potential pitfalls. If you need assistance with adding new PyWorkers, please don't hesitate to contact us.

This guide walks you through adding new backends. It is adapted from the hello_world worker's README in the Vast PyWorker repository.

There is a hello_world PyWorker implementation under workers/hello_world. This PyWorker is written for an LLM model server that runs on port 5001 and exposes two API endpoints:

  1. /generate: generates a full response to the prompt and returns it as a single JSON response
  2. /generate_stream: streams a response one token at a time

Both of these endpoints take the same API JSON payload:

{ "prompt": String, "max_response_tokens": Number | null }

We want the PyWorker to expose a matching endpoint for each of the model API endpoints above.

Structure #

All PyWorkers should have two files, data_types.py and server.py; the hello_world worker also includes a client script and a load-testing script:

```
.
└── workers
    └── hello_world
        ├── __init__.py
        ├── data_types.py   # contains data types representing model API endpoints
        ├── server.py       # contains endpoint handlers
        ├── client.py       # a script to call an endpoint through the autoscaler
        └── test_load.py    # script for load testing
```

All of the classes follow strict type hinting, and it is recommended that you type hint all of your functions. This allows your IDE (or VS Code with the Pyright plugin) to find any type errors in your implementation. You can also install Pyright with npm install pyright and run pyright from the root of the project to find any type errors.

data_types.py #

Data classes representing the model API are defined here. They must inherit from lib.data_types.ApiPayload. ApiPayload is an abstract class, and you need to implement several of its methods:

```python
import dataclasses
import inspect
import random
from typing import Dict, Any

from transformers import AutoTokenizer  # used to count tokens in a prompt
import nltk  # used to download a list of all words to generate a random prompt and benchmark the LLM model

from lib.data_types import ApiPayload, JsonDataException

nltk.download("words")
WORD_LIST = nltk.corpus.words.words()

# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")


@dataclasses.dataclass
class InputData(ApiPayload):
    prompt: str
    max_response_tokens: int

    @classmethod
    def for_test(cls) -> "ApiPayload":
        """defines how to create a payload for load testing"""
        prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
        return cls(prompt=prompt, max_response_tokens=300)

    def generate_payload_json(self) -> Dict[str, Any]:
        """defines how to convert an ApiPayload to JSON that will be sent to the model API"""
        return dataclasses.asdict(self)

    def count_workload(self) -> float:
        """defines how to calculate workload for a payload"""
        return len(tokenizer.tokenize(self.prompt))

    @classmethod
    def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
        """
        defines how to transform JSON data to AuthData and the payload type, in this case
        `InputData` defined above, which represents the data sent to the model API.
        AuthData is data generated by the autoscaler in order to authenticate payloads.
        In this case, the transformation is simple and 1:1. That is not always the case.
        See comfyui's PyWorker for more complicated examples
        """
        errors = {}
        for param in inspect.signature(cls).parameters:
            if param not in json_msg:
                errors[param] = "missing parameter"
        if errors:
            raise JsonDataException(errors)
        return cls(
            **{
                k: v
                for k, v in json_msg.items()
                if k in inspect.signature(cls).parameters
            }
        )
```
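To make the flow concrete, here is a rough sketch, not part of the worker, of how incoming JSON becomes an InputData, how its workload is measured, and how it is serialized again before being forwarded to the model server:

```python
# Illustrative only: the round trip from client JSON to InputData and back.
from workers.hello_world.data_types import InputData

# JSON arriving from a client is validated and turned into a payload object;
# a missing field raises JsonDataException
payload = InputData.from_json_msg({"prompt": "Hello there", "max_response_tokens": 64})

# the workload estimate (here, the prompt's token count) is used for autoscaling decisions
print(payload.count_workload())

# the payload is serialized again before being sent to the model server on port 5001
print(payload.generate_payload_json())

# benchmarking and load testing use a randomly generated payload
print(InputData.for_test().generate_payload_json())
```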

server.py #

For every model API endpoint you want to use, you must implement an EndpointHandler. This class handles incoming requests, processes them, sends them to the model API server, and finally returns an HTTP response. EndpointHandler has several abstract methods that must be implemented. Here, we implement two: one for /generate and one for /generate_stream:

```python
"""
AuthData is a dataclass that represents authentication data sent from the autoscaler to a
client requesting a route. See Vast's Serverless documentation for how routing and
AuthData work.

When a user receives a route for this PyWorker, they'll call the PyWorker's API with the
following JSON:
{
    auth_data: AuthData,
    payload: InputData  # defined above
}
"""

import os
import logging
import dataclasses
from typing import Any, Dict, Type, Union

from aiohttp import web, ClientResponse

from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server

from .data_types import InputData

log = logging.getLogger(__name__)


# This class is the implementer for the '/generate' endpoint of the model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        # the model API endpoint
        return "/generate"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        """this function should just return the ApiPayload subclass used by this handler"""
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        """
        defines how to convert `InputData` defined above to JSON data to be sent to
        the model API. This function too is a simple dataclass -> JSON conversion,
        but it can be more complicated. See comfyui for an example
        """
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        """
        defines how to generate an InputData for benchmarking. This needs to be defined
        in only one EndpointHandler, the one passed to the backend as the benchmark
        handler. Here we use the .for_test() method on InputData. However, in some cases
        you might need to fine-tune the InputData used for benchmarking so that it closely
        resembles the average request users call the endpoint with, in order to get the
        best autoscaling performance
        """
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        """
        defines how to convert a model API response to a response to the PyWorker client
        """
        _ = client_request
        match model_response.status:
            case 200:
                log.debug("SUCCESS")
                data = await model_response.json()
                return web.json_response(data=data)
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)
```

We also implement GenerateStreamHandler for streaming responses. It is identical to GenerateHandler except for the endpoint name and how the web response is created, since the response is streamed:

```python
class GenerateStreamHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        return "/generate_stream"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        match model_response.status:
            case 200:
                log.debug("Streaming response...")
                res = web.StreamResponse()
                res.content_type = "text/event-stream"
                await res.prepare(client_request)
                async for chunk in model_response.content:
                    await res.write(chunk)
                await res.write_eof()
                log.debug("Done streaming response")
                return res
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)
```
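On the client side, the streamed chunks can be read incrementally as they arrive. Below is a minimal sketch using aiohttp; it is not part of the worker code, the worker URL is a placeholder, and real requests must include the AuthData issued by the autoscaler (see client.py in the repository for how routes are actually obtained):

```python
# Illustrative only: reading the /generate_stream response chunk by chunk.
import asyncio

import aiohttp

WORKER_URL = "http://localhost:3000"  # placeholder; the real address comes from an autoscaler route


async def consume_stream() -> None:
    body = {
        "auth_data": {},  # placeholder; real AuthData is issued by the autoscaler
        "payload": {"prompt": "Tell me a story", "max_response_tokens": 200},
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{WORKER_URL}/generate_stream", json=body) as resp:
            async for chunk in resp.content:
                print(chunk.decode(), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(consume_stream())
```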

You can now instantiate a Backend and use it to handle requests.

```python
from lib.backend import Backend, LogAction

# the url and port of the model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"

# this is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
    "Exception: corrupted model file"  # message in the logs indicating an unrecoverable error
]

backend = Backend(
    model_server_url=MODEL_SERVER_URL,
    # location of the model log file
    model_log_file=os.environ["MODEL_LOG"],
    # for model backends that can only handle one request at a time, be sure to set this
    # to False to let the PyWorker handle queueing requests
    allow_parallel_requests=True,
    # give the backend an EndpointHandler instance that is used for benchmarking;
    # the number of benchmark runs and the number of words for a random benchmark prompt are given
    benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
    # defines how to handle specific log messages. See the docstring of LogAction for details
    log_actions=[
        (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
        (LogAction.Info, '"message":"Download'),
        *[
            (LogAction.ModelError, error_msg)
            for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
        ],
    ],
)


# this is a simple ping handler for the PyWorker
async def handle_ping(_: web.Request):
    return web.Response(body="pong")


# this is a handler for forwarding a health check to the model API
async def handle_healthcheck(_: web.Request):
    healthcheck_res = await backend.session.get("/healthcheck")
    return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)


routes = [
    web.post("/generate", backend.create_handler(GenerateHandler())),
    web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
    web.get("/ping", handle_ping),
    web.get("/healthcheck", handle_healthcheck),
]

if __name__ == "__main__":
    # start the server; called from start_server.sh
    start_server(backend, routes)
```

test_load.py #

Here you can create a script that lets you load test an endpoint group running instances with this PyWorker:

```python
from lib.test_harness import run

from .data_types import InputData

WORKER_ENDPOINT = "/generate"

if __name__ == "__main__":
    run(InputData.for_test(), WORKER_ENDPOINT)
```

You can then run the following command from the root of this repo to load test the endpoint group:

```sh
# sends 1000 requests at the rate of 0.5 requests per second
python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
```