diff --git a/examples/google-search-retrieval.rs b/examples/google-search-retrieval.rs index d209096..f6832dc 100644 --- a/examples/google-search-retrieval.rs +++ b/examples/google-search-retrieval.rs @@ -29,8 +29,13 @@ async fn main() -> Result<(), Box> { ..Default::default() }; + println!( + "Request: {}", + serde_json::to_string_pretty(&request).unwrap() + ); + let result = gemini - .generate_content(&request, "gemini-1.0-pro-002") + .generate_content(&request, "gemini-1.5-flash-002") .await?; println!("Response: {:?}", result.candidates[0].get_text().unwrap()); diff --git a/examples/google-search.rs b/examples/google-search.rs new file mode 100644 index 0000000..fb38d69 --- /dev/null +++ b/examples/google-search.rs @@ -0,0 +1,44 @@ +use gemini_rs::prelude::*; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt().init(); + let authentication_manager = gcp_auth::provider().await?; + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let gemini = GeminiClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let prompt = "What day is today?"; + + let request = GenerateContentRequest { + contents: vec![Content { + role: Some(Role::User), + parts: Some(vec![Part::Text(prompt.to_string())]), + }], + tools: Some(vec![Tools { + google_search: Some(GoogleSearch::default()), + ..Default::default() + }]), + ..Default::default() + }; + + println!( + "Request: {}", + serde_json::to_string_pretty(&request).unwrap() + ); + + let result = gemini + .generate_content(&request, "gemini-2.0-flash-001") + .await?; + + println!("Response: {:?}", result.candidates[0].get_text().unwrap()); + + Ok(()) +} diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index e4dc719..cfbc92e 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -71,15 +71,39 @@ impl GenerateContentRequestBuilder { pub struct Tools { #[serde(skip_serializing_if = "Option::is_none")] pub function_declarations: Option>, + #[serde(rename = "googleSearchRetrieval")] #[serde(skip_serializing_if = "Option::is_none")] pub google_search_retrieval: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub google_search: Option, +} + +#[derive(Clone, Default, Serialize, Deserialize)] +pub struct GoogleSearch {} + +#[derive(Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DynamicRetrievalConfig { + pub mode: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub dynamic_threshold: Option, +} + +impl Default for DynamicRetrievalConfig { + fn default() -> Self { + Self { + mode: "MODE_DYNAMIC".to_string(), + dynamic_threshold: Some(0.7), + } + } } #[derive(Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GoogleSearchRetrieval { - pub disable_attribution: bool, + pub dynamic_retrieval_config: DynamicRetrievalConfig, } #[derive(Clone, Debug, Serialize, Deserialize, Default)]