Added Response dataclass for better openai's response deserialization.
This commit is contained in:
parent
ad59a695af
commit
b77dc31ea5
@ -69,30 +69,38 @@ class Prompt:
|
||||
self.input.append(message)
|
||||
|
||||
|
||||
def filter_fields(data: dict, filter_by_class: object) -> dict:
|
||||
return {k: v for k, v in data.items() if k in tuple(f.name for f in fields(filter_by_class))}
|
||||
@dataclass(frozen=True)
|
||||
class Response:
|
||||
output: list[Message]
|
||||
|
||||
|
||||
def filter_fields(data: dict, filter_by_class: dataclass) -> dict:
|
||||
return {
|
||||
i_key: i_value
|
||||
for i_key, i_value in data.items()
|
||||
if i_key in tuple(f.name for f in fields(filter_by_class))
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from openai import OpenAI
|
||||
client = OpenAI()
|
||||
|
||||
test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')])
|
||||
test_user_message = Message(role='user', content=[InputText('Hi! How are you?')])
|
||||
test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message])
|
||||
response = client.responses.create(**asdict(test_prompt))
|
||||
response_output = response.output[0].to_dict()
|
||||
response_message = Message(**filter_fields(response_output, Message))
|
||||
print(response.to_dict().get('output'))
|
||||
test_prompt.add_message(response_message)
|
||||
print(asdict(test_prompt))
|
||||
response_raw = client.responses.create(**asdict(test_prompt))
|
||||
response = Response(**filter_fields(response_raw.to_dict(), Response))
|
||||
print(asdict(response))
|
||||
test_prompt.add_message(response.output[0])
|
||||
|
||||
test_user_message_2 = Message(role='user', content=[
|
||||
InputText('Can you tell me what is on this picture?'),
|
||||
InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg')
|
||||
])
|
||||
test_prompt.add_message(test_user_message_2)
|
||||
response = client.responses.create(**asdict(test_prompt))
|
||||
response_role = response.output[0].role
|
||||
response_text = response.output[0].content[0].text
|
||||
response_content = OutputText(response_text)
|
||||
response_message = Message(role=response_role, content=[response_content])
|
||||
test_prompt.add_message(response_message)
|
||||
response_raw = client.responses.create(**asdict(test_prompt))
|
||||
response = Response(**filter_fields(response_raw.to_dict(), Response))
|
||||
test_prompt.add_message(response.output[0])
|
||||
print(asdict(test_prompt))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user