From 8e9a186ba90ea3291f57318172b9bbd4b96ffb84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Thu, 29 Feb 2024 19:48:42 +0000 Subject: [PATCH] Adds a count_token methods to the Gemini Clent --- examples/count-tokens.rs | 32 ++++++++++++++++++++++++++++++++ src/client.rs | 28 ++++++++++++++++++++++++++-- src/types/count_tokens.rs | 13 ++++++++++--- 3 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 examples/count-tokens.rs diff --git a/examples/count-tokens.rs b/examples/count-tokens.rs new file mode 100644 index 0000000..9169f54 --- /dev/null +++ b/examples/count-tokens.rs @@ -0,0 +1,32 @@ +use std::sync::Arc; + +use gemini_rs::prelude::*; + +use gcp_auth::AuthenticationManager; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let gemini = GeminiClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let prompt = "What is the airspeed of an unladen swallow?"; + let request = CountTokensRequest { + contents: Content { + role: "user".to_string(), + parts: Some(vec![Part::Text(prompt.to_string())]), + }, + }; + let result = gemini.count_tokens(&request, "gemini-pro").await?; + println!("Response: {:?}", result); + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index 04de237..f77335e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,8 +7,8 @@ use reqwest_eventsource::{Event, EventSource}; use crate::dialogue::{Message, Role}; use crate::error::{Error, Result}; use crate::prelude::{ - Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, - TextEmbeddingRequest, TextEmbeddingResponse, + Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, + GenerateContentResponse, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse, }; use crate::{prelude::Part, token_provider::TokenProvider}; @@ -191,6 +191,30 @@ impl GeminiClient { .send() .await?; let txt_json = resp.text().await?; + tracing::debug!("text_embeddings response: {:?}", txt_json); Ok(serde_json::from_str::(&txt_json)?) } + + pub async fn count_tokens( + &self, + request: &CountTokensRequest, + model: &str, + ) -> Result { + let endpoint_url = format!( + "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:countTokens", + self.api_endpoint, self.project_id, self.location_id, model, + ); + let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; + let resp = self + .client + .post(&endpoint_url) + .bearer_auth(access_token) + .json(&request) + .send() + .await?; + + let txt_json = resp.text().await?; + tracing::debug!("count_tokens response: {:?}", txt_json); + Ok(serde_json::from_str(&txt_json)?) + } } diff --git a/src/types/count_tokens.rs b/src/types/count_tokens.rs index f3036e5..e294ecc 100644 --- a/src/types/count_tokens.rs +++ b/src/types/count_tokens.rs @@ -8,7 +8,14 @@ pub struct CountTokensRequest { } #[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CountTokensResponse { - pub total_tokens: i32, +#[serde(untagged)] +pub enum CountTokensResponse { + #[serde(rename_all = "camelCase")] + Ok { + total_tokens: i32, + total_billable_characters: u32, + }, + Error { + error: super::Error, + }, }