diff --git a/examples/system_instruction.rs b/examples/system_instruction.rs index 8330f33..1b57cec 100644 --- a/examples/system_instruction.rs +++ b/examples/system_instruction.rs @@ -18,17 +18,10 @@ async fn main() -> Result<(), Box> { let system_instruction = "Answer as if you were Yoda"; let prompt = "What is the airspeed of an unladen swallow?"; - let request = GenerateContentRequest { - contents: vec![Content { - role: Some(Role::User), - parts: Some(vec![Part::Text(prompt.to_string())]), - }], - system_instruction: Some(Content { - role: None, - parts: Some(vec![Part::Text(system_instruction.to_string())]), - }), - ..Default::default() - }; + let request = GenerateContentRequest::builder() + .add_text_content(Role::User, prompt) + .system_instruction_text(system_instruction) + .build(); let result = gemini .generate_content(&request, "gemini-1.0-pro-002") diff --git a/examples/text-from-text-streaming.rs b/examples/text-from-text-streaming.rs index 578a729..5a54c63 100644 --- a/examples/text-from-text-streaming.rs +++ b/examples/text-from-text-streaming.rs @@ -16,12 +16,7 @@ async fn main() -> Result<(), Box> { let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; let request = GenerateContentRequest::builder() - .add_content( - Content::builder() - .role(Role::User) - .add_part(Part::Text(prompt.to_string())) - .build(), - ) + .add_text_content(Role::User, prompt) .build(); let queue = gemini.stream_generate_content(&request, "gemini-pro").await; diff --git a/examples/text-from-text.rs b/examples/text-from-text.rs index 9b7cfe5..e62329f 100644 --- a/examples/text-from-text.rs +++ b/examples/text-from-text.rs @@ -16,12 +16,7 @@ async fn main() -> Result<(), Box> { let prompt = "What is the airspeed of an unladen swallow?"; let request = GenerateContentRequest::builder() - .add_content( - Content::builder() - .role(Role::User) - .add_part(Part::Text(prompt.to_string())) - .build(), - ) + .add_text_content(Role::User, prompt) .build(); let response = gemini.generate_content(&request, "gemini-pro").await?; println!("Response: {:?}", response.candidates[0].get_text().unwrap()); diff --git a/src/types/common.rs b/src/types/common.rs index 5afbf6f..8bbdeea 100644 --- a/src/types/common.rs +++ b/src/types/common.rs @@ -37,6 +37,10 @@ impl ContentBuilder { } } + pub fn add_text_part>(self, text: T) -> Self { + self.add_part(Part::Text(text.into())) + } + pub fn add_part(mut self, part: Part) -> Self { match &mut self.content.parts { Some(parts) => parts.push(part), diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index c5204b3..040952b 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use serde_json::Value; -use super::{Content, VertexApiError}; +use super::{Content, Role, VertexApiError}; use crate::error::Result; #[derive(Clone, Default, Serialize, Deserialize)] @@ -37,6 +37,11 @@ impl GenerateContentRequestBuilder { } } + pub fn add_text_content>(self, role: Role, text: T) -> Self { + let content = Content::builder().role(role).add_text_part(text).build(); + self.add_content(content) + } + pub fn add_content(mut self, content: Content) -> Self { self.request.contents.push(content); self @@ -57,6 +62,10 @@ impl GenerateContentRequestBuilder { self } + pub fn system_instruction_text>(self, text: T) -> Self { + self.system_instruction(Content::builder().add_text_part(text).build()) + } + pub fn system_instruction(mut self, system_instruction: Content) -> Self { self.request.system_instruction = Some(system_instruction); self