Adds more builders
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use std::{collections::HashMap, str::FromStr};
|
||||
use std::{collections::HashMap, str::FromStr, vec};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -20,6 +20,39 @@ impl Content {
|
||||
.collect::<String>()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn builder() -> ContentBuilder {
|
||||
ContentBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ContentBuilder {
|
||||
content: Content,
|
||||
}
|
||||
|
||||
impl ContentBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
content: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_part(mut self, part: Part) -> Self {
|
||||
match &mut self.content.parts {
|
||||
Some(parts) => parts.push(part),
|
||||
None => self.content.parts = Some(vec![part]),
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn role(mut self, role: Role) -> Self {
|
||||
self.content.role = Some(role);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Content {
|
||||
self.content
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::{Content, Part, Role, VertexApiError};
|
||||
use super::{Content, VertexApiError};
|
||||
use crate::error::Result;
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
@@ -21,19 +21,6 @@ pub struct GenerateContentRequest {
|
||||
}
|
||||
|
||||
impl GenerateContentRequest {
|
||||
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
|
||||
GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
generation_config,
|
||||
tools: None,
|
||||
system_instruction: None,
|
||||
safety_settings: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn builder() -> GenerateContentRequestBuilder {
|
||||
GenerateContentRequestBuilder::new()
|
||||
}
|
||||
@@ -50,30 +37,27 @@ impl GenerateContentRequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_prompt(mut self, prompt: &str) -> Self {
|
||||
self.request.contents = vec![Content {
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}];
|
||||
pub fn add_content(mut self, content: Content) -> Self {
|
||||
self.request.contents.push(content);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self {
|
||||
pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self {
|
||||
self.request.generation_config = Some(generation_config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tools(mut self, tools: Vec<Tools>) -> Self {
|
||||
pub fn tools(mut self, tools: Vec<Tools>) -> Self {
|
||||
self.request.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
|
||||
pub fn safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
|
||||
self.request.safety_settings = Some(safety_settings);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_system_instruction(mut self, system_instruction: Content) -> Self {
|
||||
pub fn system_instruction(mut self, system_instruction: Content) -> Self {
|
||||
self.request.system_instruction = Some(system_instruction);
|
||||
self
|
||||
}
|
||||
@@ -119,6 +103,68 @@ pub struct GenerationConfig {
|
||||
pub response_schema: Option<Value>,
|
||||
}
|
||||
|
||||
impl GenerationConfig {
|
||||
pub fn builder() -> GenerationConfigBuilder {
|
||||
GenerationConfigBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GenerationConfigBuilder {
|
||||
generation_config: GenerationConfig,
|
||||
}
|
||||
|
||||
impl GenerationConfigBuilder {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
generation_config: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_output_tokens<T: Into<i32>>(mut self, max_output_tokens: T) -> Self {
|
||||
self.generation_config.max_output_tokens = Some(max_output_tokens.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature<T: Into<f32>>(mut self, temperature: T) -> Self {
|
||||
self.generation_config.temperature = Some(temperature.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_p<T: Into<f32>>(mut self, top_p: T) -> Self {
|
||||
self.generation_config.top_p = Some(top_p.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_k<T: Into<i32>>(mut self, top_k: T) -> Self {
|
||||
self.generation_config.top_k = Some(top_k.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop_sequences<T: Into<Vec<String>>>(mut self, stop_sequences: T) -> Self {
|
||||
self.generation_config.stop_sequences = Some(stop_sequences.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn candidate_count<T: Into<u32>>(mut self, candidate_count: T) -> Self {
|
||||
self.generation_config.candidate_count = Some(candidate_count.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn response_mime_type<T: Into<String>>(mut self, response_mime_type: T) -> Self {
|
||||
self.generation_config.response_mime_type = Some(response_mime_type.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn response_schema<T: Into<Value>>(mut self, response_schema: T) -> Self {
|
||||
self.generation_config.response_schema = Some(response_schema.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> GenerationConfig {
|
||||
self.generation_config
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetySetting {
|
||||
|
||||
Reference in New Issue
Block a user