Adds a count_tokens method to the Gemini Client

This commit is contained in:
2024-02-29 19:48:42 +00:00
parent f2da435af8
commit 8e9a186ba9
3 changed files with 68 additions and 5 deletions

32
examples/count-tokens.rs Normal file
View File

@@ -0,0 +1,32 @@
use std::sync::Arc;
use gemini_rs::prelude::*;
use gcp_auth::AuthenticationManager;
/// Example: asks the Gemini API how many tokens a single user prompt contains.
///
/// The client is configured from the `API_ENDPOINT`, `PROJECT_ID` and
/// `LOCATION_ID` environment variables and authenticates through a shared
/// `AuthenticationManager`. The response is printed with its `Debug`
/// representation.
///
/// # Errors
///
/// Returns early if the authentication manager cannot be created, if any of
/// the three environment variables cannot be read, or if the `count_tokens`
/// call fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Credentials are resolved first, before any environment lookups.
    let auth = Arc::new(AuthenticationManager::new().await?);

    let endpoint = std::env::var("API_ENDPOINT")?;
    let project = std::env::var("PROJECT_ID")?;
    let location = std::env::var("LOCATION_ID")?;
    let client = GeminiClient::new(auth, endpoint, project, location);

    // A single "user" turn carrying one text part.
    let question = String::from("What is the airspeed of an unladen swallow?");
    let request = CountTokensRequest {
        contents: Content {
            role: String::from("user"),
            parts: Some(vec![Part::Text(question)]),
        },
    };

    let token_count = client.count_tokens(&request, "gemini-pro").await?;
    println!("Response: {:?}", token_count);

    Ok(())
}

View File

@@ -7,8 +7,8 @@ use reqwest_eventsource::{Event, EventSource};
use crate::dialogue::{Message, Role}; use crate::dialogue::{Message, Role};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::prelude::{ use crate::prelude::{
Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
TextEmbeddingRequest, TextEmbeddingResponse, GenerateContentResponse, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse,
}; };
use crate::{prelude::Part, token_provider::TokenProvider}; use crate::{prelude::Part, token_provider::TokenProvider};
@@ -191,6 +191,30 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
.send() .send()
.await?; .await?;
let txt_json = resp.text().await?; let txt_json = resp.text().await?;
tracing::debug!("text_embeddings response: {:?}", txt_json);
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?) Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
} }
/// Posts `request` to the `countTokens` action of `model` and deserializes the reply.
///
/// The URL is assembled from this client's `api_endpoint`, `project_id` and
/// `location_id` together with `model`. The call is authorized with a bearer
/// token obtained from the client's token provider for `AUTH_SCOPE`.
///
/// The HTTP status is not inspected here: whatever body comes back is logged
/// at debug level and then parsed as a [`CountTokensResponse`].
///
/// # Errors
///
/// Fails if the token cannot be obtained, if sending the request or reading
/// the body fails, or if the body is not valid JSON for `CountTokensResponse`.
pub async fn count_tokens(
    &self,
    request: &CountTokensRequest,
    model: &str,
) -> Result<CountTokensResponse> {
    let url = format!(
        "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:countTokens",
        self.api_endpoint, self.project_id, self.location_id, model,
    );
    let token = self.token_provider.get_token(AUTH_SCOPE).await?;

    // Send the request and pull the raw body out as text in one chain.
    let body = self
        .client
        .post(&url)
        .bearer_auth(token)
        .json(&request)
        .send()
        .await?
        .text()
        .await?;
    tracing::debug!("count_tokens response: {:?}", body);

    let parsed: CountTokensResponse = serde_json::from_str(&body)?;
    Ok(parsed)
}
} }

View File

@@ -8,7 +8,14 @@ pub struct CountTokensRequest {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CountTokensResponse {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CountTokensResponse { Ok {
pub total_tokens: i32, total_tokens: i32,
total_billable_characters: u32,
},
Error {
error: super::Error,
},
} }