Adds safety setting

This commit is contained in:
2024-08-15 14:44:59 +01:00
parent 60c20958a4
commit cb954ea5db
3 changed files with 111 additions and 0 deletions

View File

@@ -159,6 +159,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
generation_config: None,
tools: None,
system_instruction: None,
safety_settings: None,
};
let response = self.generate_content(&request, model).await?;
@@ -187,6 +188,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
generation_config: generation_config.cloned(),
tools: None,
system_instruction: None,
safety_settings: None,
};
let response = self.generate_content(&request, "gemini-pro").await?;

View File

@@ -6,10 +6,16 @@ use super::{Content, Part, VertexApiError};
use crate::error::Result;
#[derive(Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
pub contents: Vec<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
pub generation_config: Option<GenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tools>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_settings: Option<Vec<SafetySetting>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<Content>,
}
@@ -23,14 +29,17 @@ impl GenerateContentRequest {
generation_config,
tools: None,
system_instruction: None,
safety_settings: None,
}
}
}
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct Tools {
#[serde(skip_serializing_if = "Option::is_none")]
pub function_declarations: Option<Vec<FunctionDeclaration>>,
#[serde(rename = "googleSearchRetrieval")]
#[serde(skip_serializing_if = "Option::is_none")]
pub google_search_retrieval: Option<GoogleSearchRetrieval>,
}
@@ -43,21 +52,79 @@ pub struct GoogleSearchRetrieval {
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub candidate_count: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_mime_type: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: HarmCategory,
pub threshold: HarmBlockThreshold,
#[serde(skip_serializing_if = "Option::is_none")]
pub method: Option<HarmBlockMethod>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmCategory {
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
Unspecified,
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
HateSpeech,
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
DangerousContent,
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
Harassment,
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
SexuallyExplicit,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
Unspecified,
#[serde(rename = "BLOCK_LOW_AND_ABOVE")]
BlockLowAndAbove,
#[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
BlockMediumAndAbove,
#[serde(rename = "BLOCK_ONLY_HIGH")]
BlockOnlyHigh,
#[serde(rename = "BLOCK_NONE")]
BlockNone,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmBlockMethod {
#[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")]
Unspecified, // HARM_BLOCK_METHOD_UNSPECIFIED
#[serde(rename = "SEVERITY")]
Severity, // SEVERITY
#[serde(rename = "PROBABILITY")]
Probability, // PROBABILITY
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
pub citation_metadata: Option<CitationMetadata>,
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_ratings: Option<Vec<SafetyRating>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
}