Added Response dataclass for better openai's response deserialization.
This commit is contained in:
parent
ad59a695af
commit
b77dc31ea5
@ -69,30 +69,38 @@ class Prompt:
|
|||||||
self.input.append(message)
|
self.input.append(message)
|
||||||
|
|
||||||
|
|
||||||
def filter_fields(data: dict, filter_by_class: object) -> dict:
|
@dataclass(frozen=True)
|
||||||
return {k: v for k, v in data.items() if k in tuple(f.name for f in fields(filter_by_class))}
|
class Response:
|
||||||
|
output: list[Message]
|
||||||
|
|
||||||
|
|
||||||
|
def filter_fields(data: dict, filter_by_class: dataclass) -> dict:
|
||||||
|
return {
|
||||||
|
i_key: i_value
|
||||||
|
for i_key, i_value in data.items()
|
||||||
|
if i_key in tuple(f.name for f in fields(filter_by_class))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
client = OpenAI()
|
client = OpenAI()
|
||||||
|
|
||||||
test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')])
|
test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')])
|
||||||
test_user_message = Message(role='user', content=[InputText('Hi! How are you?')])
|
test_user_message = Message(role='user', content=[InputText('Hi! How are you?')])
|
||||||
test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message])
|
test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message])
|
||||||
response = client.responses.create(**asdict(test_prompt))
|
print(asdict(test_prompt))
|
||||||
response_output = response.output[0].to_dict()
|
response_raw = client.responses.create(**asdict(test_prompt))
|
||||||
response_message = Message(**filter_fields(response_output, Message))
|
response = Response(**filter_fields(response_raw.to_dict(), Response))
|
||||||
print(response.to_dict().get('output'))
|
print(asdict(response))
|
||||||
test_prompt.add_message(response_message)
|
test_prompt.add_message(response.output[0])
|
||||||
|
|
||||||
test_user_message_2 = Message(role='user', content=[
|
test_user_message_2 = Message(role='user', content=[
|
||||||
InputText('Can you tell me what is on this picture?'),
|
InputText('Can you tell me what is on this picture?'),
|
||||||
InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg')
|
InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg')
|
||||||
])
|
])
|
||||||
test_prompt.add_message(test_user_message_2)
|
test_prompt.add_message(test_user_message_2)
|
||||||
response = client.responses.create(**asdict(test_prompt))
|
response_raw = client.responses.create(**asdict(test_prompt))
|
||||||
response_role = response.output[0].role
|
response = Response(**filter_fields(response_raw.to_dict(), Response))
|
||||||
response_text = response.output[0].content[0].text
|
test_prompt.add_message(response.output[0])
|
||||||
response_content = OutputText(response_text)
|
|
||||||
response_message = Message(role=response_role, content=[response_content])
|
|
||||||
test_prompt.add_message(response_message)
|
|
||||||
print(asdict(test_prompt))
|
print(asdict(test_prompt))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user