---
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: To access CodeGemma on Hugging Face, you’re required to review
  and agree to Google’s usage license. To do this, please ensure you’re logged in
  to Hugging Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
---

# CodeGemma Model Card

> [!IMPORTANT]
>
> This repository corresponds to the CodeGemma 7B IT checkpoint for use with [Gemma PyTorch](https://github.com/google/gemma_pytorch). If you're looking for the `transformers` implementation or a more detailed model card, visit https://huggingface.co/google/codegemma-7b-it.

**Model Page**: [CodeGemma](https://ai.google.dev/gemma/docs/codegemma)

**Resources and Technical Documentation**:

* [Technical Report](https://goo.gle/codegemma)
* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible)

**Terms of Use**: [Terms](https://www.kaggle.com/models/google/codegemma/license/consent/verify/huggingface?returnModelRepoId=google/codegemma-7b-it-pytorch)

**Authors**: Google

# Sample Usage

```python
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

VARIANT = "7b-it"
MACHINE_TYPE = "cpu"
weights_dir = 'codegemma-7b-it-pytorch'

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

# Select the config that matches the checkpoint variant and point it at the tokenizer.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

# Instantiate the model in the checkpoint's dtype, load the weights, and move it to the target device.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f'codegemma-{VARIANT}.pt')
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

# Instruction-tuned checkpoints expect Gemma's chat turn markers in the prompt.
PROMPT = """<start_of_turn>user
Write a Python function to calculate the nth fibonacci number.<end_of_turn>
<start_of_turn>model
"""

model.generate(
    PROMPT,
    device=device,
    output_len=100,
)
```
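
The sample assumes the checkpoint (`codegemma-7b-it.pt`) and `tokenizer.model` are already present in `weights_dir`. A minimal sketch for fetching them from this repository with the `huggingface_hub` client, assuming you have accepted the license above and are authenticated (for example via `huggingface-cli login`):

```python
from huggingface_hub import snapshot_download

# Download the gated checkpoint files into the directory the sample above expects.
weights_dir = snapshot_download(
    repo_id="google/codegemma-7b-it-pytorch",
    local_dir="codegemma-7b-it-pytorch",
)
```

To run on a GPU instead of the CPU, set `MACHINE_TYPE = "cuda"` before building the model; as in the snippet above, the model is instantiated in the dtype returned by `model_config.get_dtype()`.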