Adds more builders

This commit is contained in:
2024-11-27 15:11:01 +00:00
parent 326b3919d1
commit db5a01afef
4 changed files with 120 additions and 26 deletions

View File

@@ -15,7 +15,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
);
let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
let request = GenerateContentRequest::from_prompt(prompt, None);
let request = GenerateContentRequest::builder()
.add_content(
Content::builder()
.role(Role::User)
.add_part(Part::Text(prompt.to_string()))
.build(),
)
.build();
let queue = gemini.stream_generate_content(&request, "gemini-pro").await;
while let Some(response) = queue.pop().await {

View File

@@ -15,7 +15,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
);
let prompt = "What is the airspeed of an unladen swallow?";
let request = GenerateContentRequest::builder().with_prompt(prompt).build();
let request = GenerateContentRequest::builder()
.add_content(
Content::builder()
.role(Role::User)
.add_part(Part::Text(prompt.to_string()))
.build(),
)
.build();
let response = gemini.generate_content(&request, "gemini-pro").await?;
println!("Response: {:?}", response.candidates[0].get_text().unwrap());

View File

@@ -1,4 +1,4 @@
use std::{collections::HashMap, str::FromStr};
use std::{collections::HashMap, str::FromStr, vec};
use serde::{Deserialize, Serialize};
@@ -20,6 +20,39 @@ impl Content {
.collect::<String>()
})
}
pub fn builder() -> ContentBuilder {
ContentBuilder::new()
}
}
pub struct ContentBuilder {
content: Content,
}
impl ContentBuilder {
pub fn new() -> Self {
Self {
content: Default::default(),
}
}
pub fn add_part(mut self, part: Part) -> Self {
match &mut self.content.parts {
Some(parts) => parts.push(part),
None => self.content.parts = Some(vec![part]),
}
self
}
pub fn role(mut self, role: Role) -> Self {
self.content.role = Some(role);
self
}
pub fn build(self) -> Content {
self.content
}
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]

View File

@@ -3,7 +3,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::{Content, Part, Role, VertexApiError};
use super::{Content, VertexApiError};
use crate::error::Result;
#[derive(Clone, Default, Serialize, Deserialize)]
@@ -21,19 +21,6 @@ pub struct GenerateContentRequest {
}
impl GenerateContentRequest {
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
GenerateContentRequest {
contents: vec![Content {
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config,
tools: None,
system_instruction: None,
safety_settings: None,
}
}
pub fn builder() -> GenerateContentRequestBuilder {
GenerateContentRequestBuilder::new()
}
@@ -50,30 +37,27 @@ impl GenerateContentRequestBuilder {
}
}
pub fn with_prompt(mut self, prompt: &str) -> Self {
self.request.contents = vec![Content {
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}];
pub fn add_content(mut self, content: Content) -> Self {
self.request.contents.push(content);
self
}
pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self {
pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self {
self.request.generation_config = Some(generation_config);
self
}
pub fn with_tools(mut self, tools: Vec<Tools>) -> Self {
pub fn tools(mut self, tools: Vec<Tools>) -> Self {
self.request.tools = Some(tools);
self
}
pub fn with_safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
pub fn safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
self.request.safety_settings = Some(safety_settings);
self
}
pub fn with_system_instruction(mut self, system_instruction: Content) -> Self {
pub fn system_instruction(mut self, system_instruction: Content) -> Self {
self.request.system_instruction = Some(system_instruction);
self
}
@@ -119,6 +103,68 @@ pub struct GenerationConfig {
pub response_schema: Option<Value>,
}
impl GenerationConfig {
pub fn builder() -> GenerationConfigBuilder {
GenerationConfigBuilder::new()
}
}
pub struct GenerationConfigBuilder {
generation_config: GenerationConfig,
}
impl GenerationConfigBuilder {
fn new() -> Self {
Self {
generation_config: Default::default(),
}
}
pub fn max_output_tokens<T: Into<i32>>(mut self, max_output_tokens: T) -> Self {
self.generation_config.max_output_tokens = Some(max_output_tokens.into());
self
}
pub fn temperature<T: Into<f32>>(mut self, temperature: T) -> Self {
self.generation_config.temperature = Some(temperature.into());
self
}
pub fn top_p<T: Into<f32>>(mut self, top_p: T) -> Self {
self.generation_config.top_p = Some(top_p.into());
self
}
pub fn top_k<T: Into<i32>>(mut self, top_k: T) -> Self {
self.generation_config.top_k = Some(top_k.into());
self
}
pub fn stop_sequences<T: Into<Vec<String>>>(mut self, stop_sequences: T) -> Self {
self.generation_config.stop_sequences = Some(stop_sequences.into());
self
}
pub fn candidate_count<T: Into<u32>>(mut self, candidate_count: T) -> Self {
self.generation_config.candidate_count = Some(candidate_count.into());
self
}
pub fn response_mime_type<T: Into<String>>(mut self, response_mime_type: T) -> Self {
self.generation_config.response_mime_type = Some(response_mime_type.into());
self
}
pub fn response_schema<T: Into<Value>>(mut self, response_schema: T) -> Self {
self.generation_config.response_schema = Some(response_schema.into());
self
}
pub fn build(self) -> GenerationConfig {
self.generation_config
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {