pytorch-http/__main__.py
2023-09-23 15:10:35 -05:00

45 lines
1020 B
Python

from bottle import request, response, post, template, HTTPResponse, run
from dotenv import load_dotenv
from os import getenv
import torch
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
# Load .env before anything reads the environment, so that API_KEYS, PORT and
# any variables honored by the Hugging Face download below come from it.
load_dotenv()


def _parse_api_keys(raw):
    """Split a comma-separated API_KEYS value into a list of usable keys.

    Surrounding whitespace is stripped from each key. Empty entries (from an
    empty variable or a trailing comma) are dropped so that an empty string
    can never act as a valid key. Returns [] when ``raw`` is None.
    """
    if raw is None:
        return []
    return [key.strip() for key in raw.split(",") if key.strip()]


# Accepted API keys. Fail fast with a clear message, before the multi-GB model
# load, because the only endpoint requires a key and would be unusable.
api_keys = _parse_api_keys(getenv('API_KEYS'))
if not api_keys:
    raise RuntimeError(
        "API_KEYS must be set to a comma-separated list of non-empty keys"
    )

# Listening port. Values from the environment are strings, so coerce to int.
port = int(getenv('PORT') or 9010)

# Directory where the FLAN-T5 tokenizer and model files are cached.
MODEL_CACHE_DIR = "model"

tokenizer = T5Tokenizer.from_pretrained(
    "google/flan-t5-xl",
    cache_dir=MODEL_CACHE_DIR,
)
# Use the same cache_dir as the tokenizer. Previously the model went to the
# library's default cache while the tokenizer went to ./model.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl",
    cache_dir=MODEL_CACHE_DIR,
)
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
@post('/')
def gen():
    """Run the text2text-generation pipeline on the request's ``"input"``.

    The request must carry an ``Authorization: X-Api-Key <key>`` header whose
    key is one of ``api_keys``. The JSON body must be an object with an
    ``"input"`` field, which is passed unchanged to ``generator``.

    Returns:
        ``{"output": <generated_text of the first result>}``, which Bottle
        serializes to JSON.

    Raises:
        HTTPResponse: 401 if the Authorization header is missing, malformed,
            uses another scheme, or carries an unknown or empty key. 400 if
            the body is not a JSON object containing ``"input"``.
    """
    import hmac  # constant-time comparison of secret API keys

    auth = request.get_header('authorization')
    if auth is None:
        raise HTTPResponse(status=401)
    # partition() never raises. The old 2-way unpack of split(" ") raised
    # ValueError (a 500) for headers without exactly one space.
    scheme, _, key = auth.partition(" ")
    # Reject an empty key explicitly so "X-Api-Key " can never authenticate.
    # Compare as bytes so that non-ASCII header values cannot make
    # compare_digest raise TypeError.
    authorized = bool(key) and any(
        hmac.compare_digest(key.encode("utf-8"), candidate.encode("utf-8"))
        for candidate in api_keys
    )
    if scheme != 'X-Api-Key' or not authorized:
        raise HTTPResponse(status=401)

    # request.json is None for non-JSON requests and may be a non-object
    # (e.g. a JSON array). Answer 400 instead of crashing with a 500.
    body = request.json
    if not isinstance(body, dict) or "input" not in body:
        raise HTTPResponse(
            body='Request body must be a JSON object with an "input" field',
            status=400,
        )

    # Named `text` rather than `input` to avoid shadowing the builtin.
    text = body["input"]
    output = generator(text)
    print(text)
    print(output)
    return {"output": output[0]["generated_text"]}
if __name__ == "__main__":
    # Honor the configured PORT (default 9010) instead of a hard-coded port.
    # int() also covers a `port` that was read from the environment as a str.
    run(host='0.0.0.0', port=int(port))